Browse Source

ESQL - KNN function uses LIMIT for K, transforms to exact search when not pushed down (#132944)

Carlos Delgado 1 tháng trước cách đây
mục cha
commit
e76c2a6c57
25 tập tin đã thay đổi với 572 bổ sung230 xóa
  1. 1 1
      docs/reference/query-languages/esql/_snippets/functions/examples/knn.md
  2. 3 3
      docs/reference/query-languages/esql/_snippets/functions/functionNamedParams/knn.md
  3. 0 3
      docs/reference/query-languages/esql/_snippets/functions/parameters/knn.md
  4. 1 1
      docs/reference/query-languages/esql/images/functions/knn.svg
  5. 1 1
      docs/reference/query-languages/esql/kibana/definition/functions/knn.json
  6. 1 1
      docs/reference/query-languages/esql/kibana/docs/functions/knn.md
  7. 10 4
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluator.java
  8. 3 4
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ScoreOperator.java
  9. 87 55
      x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec
  10. 10 7
      x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java
  11. 1 1
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
  12. 1 1
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
  13. 13 2
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java
  14. 68 61
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java
  15. 1 1
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java
  16. 2 0
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java
  17. 69 0
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushLimitToKnn.java
  18. 17 7
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java
  19. 1 1
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java
  20. 2 3
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
  21. 16 17
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java
  22. 2 2
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java
  23. 2 1
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/AbstractLogicalPlanOptimizerTests.java
  24. 98 41
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java
  25. 162 12
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java

+ 1 - 1
docs/reference/query-languages/esql/_snippets/functions/examples/knn.md

@@ -4,7 +4,7 @@
 
 ```esql
 from colors metadata _score
-| where knn(rgb_vector, [0, 120, 0], 10)
+| where knn(rgb_vector, [0, 120, 0])
 | sort _score desc, color asc
 ```
 

+ 3 - 3
docs/reference/query-languages/esql/_snippets/functions/functionNamedParams/knn.md

@@ -2,12 +2,12 @@
 
 **Supported function named parameters**
 
-`num_candidates`
-:   (integer) The number of nearest neighbor candidates to consider per shard while doing knn search. Cannot exceed 10,000. Increasing num_candidates tends to improve the accuracy of the final results. Defaults to 1.5 * k
-
 `boost`
 :   (float) Floating point number used to decrease or increase the relevance scores of the query.Defaults to 1.0.
 
+`min_candidates`
+:   (integer) The minimum number of nearest neighbor candidates to consider per shard while doing knn search.  KNN may use a higher number of candidates in case the query can't use a approximate results. Cannot exceed 10,000. Increasing min_candidates tends to improve the accuracy of the final results. Defaults to 1.5 * LIMIT used for the query.
+
 `rescore_oversample`
 :   (double) Applies the specified oversampling for rescoring quantized vectors. See [oversampling and rescoring quantized vectors](docs-content://solutions/search/vector/knn.md#dense-vector-knn-search-rescoring) for details.
 

+ 0 - 3
docs/reference/query-languages/esql/_snippets/functions/parameters/knn.md

@@ -8,9 +8,6 @@
 `query`
 :   Vector value to find top nearest neighbours for.
 
-`k`
-:   The number of nearest neighbors to return from each shard. Elasticsearch collects k results from each shard, then merges them to find the global top results. This value must be less than or equal to num_candidates.
-
 `options`
 :   (Optional) kNN additional options as [function named parameters](/reference/query-languages/esql/esql-syntax.md#esql-function-named-params). See [knn query](/reference/query-languages/query-dsl/query-dsl-match-query.md#query-dsl-knn-query) for more information.
 

+ 1 - 1
docs/reference/query-languages/esql/images/functions/knn.svg

@@ -1 +1 @@
-<svg version="1.1" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg" width="652" height="61" viewbox="0 0 652 61"><defs><style type="text/css">.c{fill:none;stroke:#222222;}.k{fill:#000000;font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;font-size:20px;}.s{fill:#e4f4ff;stroke:#222222;}.syn{fill:#8D8D8D;font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;font-size:20px;}</style></defs><path class="c" d="M0 31h5m56 0h10m32 0h10m80 0h10m32 0h10m80 0h10m32 0h10m32 0h10m32 0h30m104 0h20m-139 0q5 0 5 5v10q0 5 5 5h114q5 0 5-5v-10q0-5 5-5m5 0h10m32 0h5"/><rect class="s" x="5" y="5" width="56" height="36"/><text class="k" x="15" y="31">KNN</text><rect class="s" x="71" y="5" width="32" height="36" rx="7"/><text class="syn" x="81" y="31">(</text><rect class="s" x="113" y="5" width="80" height="36" rx="7"/><text class="k" x="123" y="31">field</text><rect class="s" x="203" y="5" width="32" height="36" rx="7"/><text class="syn" x="213" y="31">,</text><rect class="s" x="245" y="5" width="80" height="36" rx="7"/><text class="k" x="255" y="31">query</text><rect class="s" x="335" y="5" width="32" height="36" rx="7"/><text class="syn" x="345" y="31">,</text><rect class="s" x="377" y="5" width="32" height="36" rx="7"/><text class="k" x="387" y="31">k</text><rect class="s" x="419" y="5" width="32" height="36" rx="7"/><text class="syn" x="429" y="31">,</text><rect class="s" x="481" y="5" width="104" height="36" rx="7"/><text class="k" x="491" y="31">options</text><rect class="s" x="615" y="5" width="32" height="36" rx="7"/><text class="syn" x="625" y="31">)</text></svg>
+<svg version="1.1" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg" width="568" height="61" viewbox="0 0 568 61"><defs><style type="text/css">.c{fill:none;stroke:#222222;}.k{fill:#000000;font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;font-size:20px;}.s{fill:#e4f4ff;stroke:#222222;}.syn{fill:#8D8D8D;font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;font-size:20px;}</style></defs><path class="c" d="M0 31h5m56 0h10m32 0h10m80 0h10m32 0h10m80 0h10m32 0h30m104 0h20m-139 0q5 0 5 5v10q0 5 5 5h114q5 0 5-5v-10q0-5 5-5m5 0h10m32 0h5"/><rect class="s" x="5" y="5" width="56" height="36"/><text class="k" x="15" y="31">KNN</text><rect class="s" x="71" y="5" width="32" height="36" rx="7"/><text class="syn" x="81" y="31">(</text><rect class="s" x="113" y="5" width="80" height="36" rx="7"/><text class="k" x="123" y="31">field</text><rect class="s" x="203" y="5" width="32" height="36" rx="7"/><text class="syn" x="213" y="31">,</text><rect class="s" x="245" y="5" width="80" height="36" rx="7"/><text class="k" x="255" y="31">query</text><rect class="s" x="335" y="5" width="32" height="36" rx="7"/><text class="syn" x="345" y="31">,</text><rect class="s" x="397" y="5" width="104" height="36" rx="7"/><text class="k" x="407" y="31">options</text><rect class="s" x="531" y="5" width="32" height="36" rx="7"/><text class="syn" x="541" y="31">)</text></svg>

+ 1 - 1
docs/reference/query-languages/esql/kibana/definition/functions/knn.json

@@ -5,7 +5,7 @@
   "description" : "Finds the k nearest vectors to a query vector, as measured by a similarity metric. knn function finds nearest vectors through approximate search on indexed dense_vectors.",
   "signatures" : [ ],
   "examples" : [
-    "from colors metadata _score\n| where knn(rgb_vector, [0, 120, 0], 10)\n| sort _score desc, color asc"
+    "from colors metadata _score\n| where knn(rgb_vector, [0, 120, 0])\n| sort _score desc, color asc"
   ],
   "preview" : true,
   "snapshot_only" : true

+ 1 - 1
docs/reference/query-languages/esql/kibana/docs/functions/knn.md

@@ -5,6 +5,6 @@ Finds the k nearest vectors to a query vector, as measured by a similarity metri
 
 ```esql
 from colors metadata _score
-| where knn(rgb_vector, [0, 120, 0], 10)
+| where knn(rgb_vector, [0, 120, 0])
 | sort _score desc, color asc
 ```

+ 10 - 4
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluator.java

@@ -60,10 +60,16 @@ public abstract class LuceneQueryEvaluator<T extends Block.Builder> implements R
     }
 
     public Block executeQuery(Page page) {
-        // Lucene based operators retrieve DocVectors as first block
-        Block block = page.getBlock(0);
-        assert block instanceof DocBlock : "LuceneQueryExpressionEvaluator expects DocBlock as input";
-        DocVector docs = (DocVector) block.asVector();
+        // Search for DocVector block
+        Block docBlock = null;
+        for (int i = 0; i < page.getBlockCount(); i++) {
+            if (page.getBlock(i) instanceof DocBlock) {
+                docBlock = page.getBlock(i);
+                break;
+            }
+        }
+        assert docBlock != null : "LuceneQueryExpressionEvaluator expects a DocBlock";
+        DocVector docs = (DocVector) docBlock.asVector();
         try {
             if (docs.singleSegmentNonDecreasing()) {
                 return evalSingleSegmentNonDecreasing(docs);

+ 3 - 4
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ScoreOperator.java

@@ -9,7 +9,6 @@ package org.elasticsearch.compute.operator;
 
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BlockFactory;
-import org.elasticsearch.compute.data.DocVector;
 import org.elasticsearch.compute.data.DoubleBlock;
 import org.elasticsearch.compute.data.DoubleVector;
 import org.elasticsearch.compute.data.Page;
@@ -46,9 +45,9 @@ public class ScoreOperator extends AbstractPageMappingOperator {
 
     @Override
     protected Page process(Page page) {
-        assert page.getBlockCount() >= 2 : "Expected at least 2 blocks, got " + page.getBlockCount();
-        assert page.getBlock(0).asVector() instanceof DocVector : "Expected a DocVector, got " + page.getBlock(0).asVector();
-        assert page.getBlock(1).asVector() instanceof DoubleVector : "Expected a DoubleVector, got " + page.getBlock(1).asVector();
+        assert page.getBlockCount() > scoreBlockPosition : "Expected to get a score block in position " + scoreBlockPosition;
+        assert page.getBlock(scoreBlockPosition).asVector() instanceof DoubleVector
+            : "Expected a DoubleVector as a score block, got " + page.getBlock(scoreBlockPosition).asVector();
 
         Block[] blocks = new Block[page.getBlockCount()];
         for (int i = 0; i < page.getBlockCount(); i++) {

+ 87 - 55
x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec

@@ -3,11 +3,11 @@
 # top-n query at the shard level 
 
 knnSearch
-required_capability: knn_function_v3
+required_capability: knn_function_v4
 
 // tag::knn-function[]
 from colors metadata _score 
-| where knn(rgb_vector, [0, 120, 0], 10) 
+| where knn(rgb_vector, [0, 120, 0]) 
 | sort _score desc, color asc
 // end::knn-function[]
 | keep color, rgb_vector
@@ -30,10 +30,10 @@ chartreuse | [127.0, 255.0, 0.0]
 ;
 
 knnSearchWithSimilarityOption
-required_capability: knn_function_v3
+required_capability: knn_function_v4
 
 from colors metadata _score 
-| where knn(rgb_vector, [255,192,203], 140, {"similarity": 40})
+| where knn(rgb_vector, [255,192,203], {"similarity": 40})
 | sort _score desc, color asc
 | keep color, rgb_vector
 ;
@@ -46,13 +46,14 @@ wheat      | [245.0, 222.0, 179.0]
 ;
 
 knnHybridSearch
-required_capability: knn_function_v3
+required_capability: knn_function_v4
 
 from colors metadata _score 
-| where match(color, "blue") or knn(rgb_vector, [65,105,225], 10)
+| where match(color, "blue") or knn(rgb_vector, [65,105,225])
 | where primary == true
 | sort _score desc, color asc
 | keep color, rgb_vector
+| limit 10
 ;
 
 color:text | rgb_vector:dense_vector
@@ -68,10 +69,10 @@ yellow     | [255.0, 255.0, 0.0]
 ;
 
 knnWithPrefilter
-required_capability: knn_function_v3
+required_capability: knn_function_v4
 
 from colors
-| where knn(rgb_vector, [120,180,0], 10) and (match(color, "olive") or match(color, "green")) 
+| where knn(rgb_vector, [120,180,0]) and (match(color, "olive") or match(color, "green")) 
 | sort color asc
 | keep color
 ;
@@ -82,10 +83,10 @@ olive
 ;
 
 knnWithNegatedPrefilter
-required_capability: knn_function_v3
+required_capability: knn_function_v4
 
 from colors metadata _score
-| where knn(rgb_vector, [128,128,0], 10) and not (match(color, "olive") or match(color, "chocolate")) 
+| where knn(rgb_vector, [128,128,0]) and not (match(color, "olive") or match(color, "chocolate")) 
 | sort _score desc, color asc
 | keep color, rgb_vector
 | LIMIT 10
@@ -105,11 +106,11 @@ orange     | [255.0, 165.0, 0.0]
 ;
 
 knnAfterKeep
-required_capability: knn_function_v3
+required_capability: knn_function_v4
 
 from colors metadata _score
 | keep rgb_vector, color, _score 
-| where knn(rgb_vector, [128,255,0], 140)
+| where knn(rgb_vector, [128,255,0])
 | sort _score desc, color asc
 | keep rgb_vector
 | limit 5
@@ -124,11 +125,11 @@ rgb_vector:dense_vector
 ;
 
 knnAfterDrop
-required_capability: knn_function_v3
+required_capability: knn_function_v4
 
 from colors metadata _score
 | drop primary
-| where knn(rgb_vector, [128,250,0], 140)
+| where knn(rgb_vector, [128,250,0])
 | sort _score desc, color asc
 | keep color, rgb_vector
 | limit 5
@@ -143,11 +144,11 @@ lime           | [0.0, 255.0, 0.0]
 ;
 
 knnAfterEval
-required_capability: knn_function_v3
+required_capability: knn_function_v4
 
 from colors metadata _score
 | eval composed_name = locate(color, " ") > 0 
-| where knn(rgb_vector, [128,128,0], 140)
+| where knn(rgb_vector, [128,128,0])
 | sort _score desc, color asc
 | keep color, composed_name 
 | limit 5
@@ -162,12 +163,13 @@ golden rod | true
 ;
 
 knnWithConjunction
-required_capability: knn_function_v3
+required_capability: knn_function_v4
 
 from colors metadata _score 
-| where knn(rgb_vector, [255,255,238], 10) and hex_code like "#FFF*" 
+| where knn(rgb_vector, [255,255,238]) and hex_code like "#FFF*" 
 | sort _score desc, color asc
 | keep color, hex_code, rgb_vector
+| limit 10
 ;
 
 color:text    | hex_code:keyword | rgb_vector:dense_vector
@@ -181,10 +183,10 @@ yellow        | #FFFF00          | [255.0, 255.0, 0.0]
 ;
 
 knnWithDisjunctionAndFiltersConjunction
-required_capability: knn_function_v3
+required_capability: knn_function_v4
 
 from colors metadata _score 
-| where (knn(rgb_vector, [0,255,255], 140) or knn(rgb_vector, [128, 0, 255], 10)) and primary == true 
+| where (knn(rgb_vector, [0,255,255]) or knn(rgb_vector, [128, 0, 255])) and primary == true 
 | keep color, rgb_vector, _score
 | sort _score desc, color asc
 | drop _score
@@ -204,10 +206,10 @@ yellow     | [255.0, 255.0, 0.0]
 ;
 
 knnWithNegationsAndFiltersConjunction
-required_capability: knn_function_v3
+required_capability: knn_function_v4
 
 from colors metadata _score 
-| where (knn(rgb_vector, [0,255,255], 140) and not(primary == true and match(color, "blue"))) 
+| where (knn(rgb_vector, [0,255,255]) and not(primary == true and match(color, "blue"))) 
 | sort _score desc, color asc
 | keep color, rgb_vector
 | limit 10
@@ -227,11 +229,11 @@ azure       | [240.0, 255.0, 255.0]
 ;
 
 knnWithNonPushableConjunction
-required_capability: knn_function_v3
+required_capability: knn_function_v4
 
 from colors metadata _score
 | eval composed_name = locate(color, " ") > 0 
-| where knn(rgb_vector, [128,128,0], 140) and composed_name == false
+| where knn(rgb_vector, [128,128,0], {"min_candidates": 100}) and composed_name == false
 | sort _score desc, color asc
 | keep color, composed_name
 | limit 10
@@ -251,58 +253,88 @@ maroon     | false
 ;
 
 testKnnWithNonPushableDisjunctions
-required_capability: knn_function_v3
+required_capability: knn_function_v4
 
 from colors metadata _score 
-| where knn(rgb_vector, [128,128,0], 140, {"similarity": 30}) or length(color) > 10 
+| where knn(rgb_vector, [128,128,0]) or length(color) > 10 
 | sort _score desc, color asc
-| keep color 
+| keep color
+| limit 10
 ;
 
 color:text
-olive
-aqua marine
-lemon chiffon
-papaya whip
+olive          
+sienna         
+chocolate      
+peru           
+golden rod     
+brown          
+firebrick      
+chartreuse     
+gray           
+green   
 ;
 
-testKnnWithNonPushableDisjunctionsOnComplexExpressions
-required_capability: knn_function_v3
+testKnnWithNonPushableDisjunctionsAndMinCandidates
+required_capability: knn_function_v4
 
 from colors metadata _score 
-| where (knn(rgb_vector, [128,128,0], 140, {"similarity": 70}) and length(color) < 10) or (knn(rgb_vector, [128,0,128], 140, {"similarity": 60}) and primary == false) 
+| where (knn(rgb_vector, [128,128,0], {"min_candidates": 2}) and length(color) > 10) or (knn(rgb_vector, [128,0,128], {"min_candidates": 2}) and primary == true) 
 | sort _score desc, color asc
 | keep color, primary
 ;
 
 color:text   | primary:boolean
-olive        | false
-purple       | false
-indigo       | false
-;
+gray          | true
+green         | true
+red           | true
+black         | true
+magenta       | true
+yellow        | true
+blue          | true
+aqua marine   | false
+papaya whip   | false
+lemon chiffon | false
+white         | true
+cyan          | true
+;
+
+testKnnWithStats
+required_capability: knn_function_v4
 
-testKnnInStatsNonPushable
-required_capability: knn_function_v3
-
-from colors 
-| where length(color) < 10 
-| stats c = count(*) where knn(rgb_vector, [128,128,255], 140)
+from colors metadata _score 
+| where knn(rgb_vector, [128,128,0])
+| sort _score desc, color asc
+| limit 15
+| stats c = count(*)
 ;
 
-c: long 
-50      
+c:long
+15
 ;
 
-testKnnInStatsWithGrouping
-required_capability: knn_function_v3
-required_capability: full_text_functions_in_stats_where
+testKnnWithRerank
+required_capability: knn_function_v4
+required_capability: rerank
 
-from colors 
-| where length(color) < 10 
-| stats c = count(*) where knn(rgb_vector, [128,128,255], 140) by primary
+from colors metadata _score 
+| where knn(rgb_vector, [100,120,0])
+| sort _score desc, color asc
+| limit 10
+| rerank rerank_score = "deepest blue" ON color WITH { "inference_id" : "test_reranker" }
+| sort rerank_score desc, color asc
+| keep color
 ;
 
-c: long       | primary: boolean    
-41            | false          
-9             | true           
+color:text
+gray
+peru
+brown
+green
+olive
+maroon
+sienna
+chocolate
+firebrick
+golden rod
 ;

+ 10 - 7
x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java

@@ -74,9 +74,10 @@ public class KnnFunctionIT extends AbstractEsqlIntegTestCase {
 
         var query = String.format(Locale.ROOT, """
             FROM test METADATA _score
-            | WHERE knn(vector, %s, 10)
+            | WHERE knn(vector, %s)
             | KEEP id, _score, vector
             | SORT _score DESC
+            | LIMIT 10
             """, Arrays.toString(queryVector));
 
         try (var resp = run(query)) {
@@ -113,9 +114,10 @@ public class KnnFunctionIT extends AbstractEsqlIntegTestCase {
 
         var query = String.format(Locale.ROOT, """
             FROM test METADATA _score
-            | WHERE knn(vector, %s, 5)
+            | WHERE knn(vector, %s)
             | KEEP id, _score, vector
             | SORT _score DESC
+            | LIMIT 5
             """, Arrays.toString(queryVector));
 
         try (var resp = run(query)) {
@@ -131,12 +133,12 @@ public class KnnFunctionIT extends AbstractEsqlIntegTestCase {
         float[] queryVector = new float[numDims];
         Arrays.fill(queryVector, 0.0f);
 
-        // TODO we need to decide what to do when / if user uses k for limit, as no more than k results will be returned from knn query
         var query = String.format(Locale.ROOT, """
             FROM test METADATA _score
-            | WHERE knn(vector, %s, 5) OR id > 100
+            | WHERE knn(vector, %s) OR id > 100
             | KEEP id, _score, vector
             | SORT _score DESC
+            | LIMIT 5
             """, Arrays.toString(queryVector));
 
         try (var resp = run(query)) {
@@ -155,7 +157,7 @@ public class KnnFunctionIT extends AbstractEsqlIntegTestCase {
         // We retrieve 5 from knn, but must be prefiltered with id > 5 or no result will be returned as it would be post-filtered
         var query = String.format(Locale.ROOT, """
             FROM test METADATA _score
-            | WHERE knn(vector, %s, 5) AND id > 5 AND id <= 10
+            | WHERE knn(vector, %s) AND id > 5 AND id <= 10
             | KEEP id, _score, vector
             | SORT _score DESC
             | LIMIT 5
@@ -178,7 +180,8 @@ public class KnnFunctionIT extends AbstractEsqlIntegTestCase {
         var query = String.format(Locale.ROOT, """
             FROM test
             | LOOKUP JOIN test_lookup ON id
-            | WHERE KNN(lookup_vector, %s, 5) OR id > 100
+            | WHERE KNN(lookup_vector, %s) OR id > 100
+            | LIMIT 5
             """, Arrays.toString(queryVector));
 
         var error = expectThrows(VerificationException.class, () -> run(query));
@@ -193,7 +196,7 @@ public class KnnFunctionIT extends AbstractEsqlIntegTestCase {
 
     @Before
     public void setup() throws IOException {
-        assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+        assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());
 
         var indexName = "test";
         var client = client().admin().indices();

+ 1 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

@@ -1291,7 +1291,7 @@ public class EsqlCapabilities {
         /**
          * Support knn function
          */
-        KNN_FUNCTION_V3(Build.current().isSnapshot()),
+        KNN_FUNCTION_V4(Build.current().isSnapshot()),
 
         /**
          * Support for the LIKE operator with a list of wildcards.

+ 1 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java

@@ -505,7 +505,7 @@ public class EsqlFunctionRegistry {
                 def(FirstOverTime.class, uni(FirstOverTime::new), "first_over_time"),
                 def(Score.class, uni(Score::new), Score.NAME),
                 def(Term.class, bi(Term::new), "term"),
-                def(Knn.class, quad(Knn::new), "knn"),
+                def(Knn.class, tri(Knn::new), "knn"),
                 def(ToGeohash.class, ToGeohash::new, "to_geohash"),
                 def(ToGeotile.class, ToGeotile::new, "to_geotile"),
                 def(ToGeohex.class, ToGeohex::new, "to_geohex"),

+ 13 - 2
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java

@@ -384,18 +384,29 @@ public abstract class FullTextFunction extends Function
         ShardConfig[] shardConfigs = new ShardConfig[shardContexts.size()];
         int i = 0;
         for (EsPhysicalOperationProviders.ShardContext shardContext : shardContexts) {
-            shardConfigs[i++] = new ShardConfig(shardContext.toQuery(queryBuilder()), shardContext.searcher());
+            shardConfigs[i++] = new ShardConfig(shardContext.toQuery(evaluatorQueryBuilder()), shardContext.searcher());
         }
         return new LuceneQueryExpressionEvaluator.Factory(shardConfigs);
     }
 
+    /**
+     * Returns the query builder to be used when the function cannot be pushed down to Lucene, but uses a
+     * {@link org.elasticsearch.compute.lucene.LuceneQueryEvaluator} instead
+     *
+     * @return the query builder to be used in the {@link org.elasticsearch.compute.lucene.LuceneQueryEvaluator}
+     */
+    protected QueryBuilder evaluatorQueryBuilder() {
+        // Use the same query builder as for the translation by default
+        return queryBuilder();
+    }
+
     @Override
     public ScoreOperator.ExpressionScorer.Factory toScorer(ToScorer toScorer) {
         List<EsPhysicalOperationProviders.ShardContext> shardContexts = toScorer.shardContexts();
         ShardConfig[] shardConfigs = new ShardConfig[shardContexts.size()];
         int i = 0;
         for (EsPhysicalOperationProviders.ShardContext shardContext : shardContexts) {
-            shardConfigs[i++] = new ShardConfig(shardContext.toQuery(queryBuilder()), shardContext.searcher());
+            shardConfigs[i++] = new ShardConfig(shardContext.toQuery(evaluatorQueryBuilder()), shardContext.searcher());
         }
         return new LuceneQueryScoreEvaluator.Factory(shardConfigs);
     }

+ 68 - 61
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java

@@ -7,15 +7,17 @@
 
 package org.elasticsearch.xpack.esql.expression.function.vector;
 
-import org.apache.logging.log4j.LogManager;
-import org.apache.logging.log4j.Logger;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.search.vectors.ExactKnnQueryBuilder;
+import org.elasticsearch.search.vectors.VectorData;
 import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
 import org.elasticsearch.xpack.esql.capabilities.PostAnalysisPlanVerificationAware;
+import org.elasticsearch.xpack.esql.capabilities.PostOptimizationVerificationAware;
 import org.elasticsearch.xpack.esql.capabilities.TranslationAware;
+import org.elasticsearch.xpack.esql.common.Failure;
 import org.elasticsearch.xpack.esql.common.Failures;
 import org.elasticsearch.xpack.esql.core.InvalidArgumentException;
 import org.elasticsearch.xpack.esql.core.expression.Expression;
@@ -54,14 +56,11 @@ import java.util.function.BiConsumer;
 import static java.util.Map.entry;
 import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
 import static org.elasticsearch.index.query.AbstractQueryBuilder.BOOST_FIELD;
-import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.K_FIELD;
-import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD;
 import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VECTOR_SIMILARITY_FIELD;
 import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
 import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FOURTH;
 import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND;
 import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.THIRD;
-import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable;
 import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull;
 import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
 import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
@@ -70,20 +69,26 @@ import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER;
 import static org.elasticsearch.xpack.esql.expression.Foldables.TypeResolutionValidator.forPreOptimizationValidation;
 import static org.elasticsearch.xpack.esql.expression.Foldables.resolveTypeQuery;
 
-public class Knn extends FullTextFunction implements OptionalArgument, VectorFunction, PostAnalysisPlanVerificationAware {
-    private final Logger log = LogManager.getLogger(getClass());
+public class Knn extends FullTextFunction
+    implements
+        OptionalArgument,
+        VectorFunction,
+        PostAnalysisPlanVerificationAware,
+        PostOptimizationVerificationAware {
 
     public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Knn", Knn::readFrom);
 
     private final Expression field;
     // k is not serialized as it's already included in the query builder on the rewrite step before being sent to data nodes
-    private final transient Expression k;
+    private final transient Integer k;
     private final Expression options;
     // Expressions to be used as prefilters in knn query
     private final List<Expression> filterExpressions;
 
+    public static final String MIN_CANDIDATES_OPTION = "min_candidates";
+
     public static final Map<String, DataType> ALLOWED_OPTIONS = Map.ofEntries(
-        entry(NUM_CANDS_FIELD.getPreferredName(), INTEGER),
+        entry(MIN_CANDIDATES_OPTION, INTEGER),
         entry(VECTOR_SIMILARITY_FIELD.getPreferredName(), FLOAT),
         entry(BOOST_FIELD.getPreferredName(), FLOAT),
         entry(KnnQuery.RESCORE_OVERSAMPLE_FIELD, FLOAT)
@@ -105,13 +110,6 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
             type = { "dense_vector" },
             description = "Vector value to find top nearest neighbours for."
         ) Expression query,
-        @Param(
-            name = "k",
-            type = { "integer" },
-            description = "The number of nearest neighbors to return from each shard. "
-                + "Elasticsearch collects k results from each shard, then merges them to find the global top results. "
-                + "This value must be less than or equal to num_candidates."
-        ) Expression k,
         @MapParam(
             name = "options",
             params = {
@@ -123,12 +121,13 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
                         + "Defaults to 1.0."
                 ),
                 @MapParam.MapParamEntry(
-                    name = "num_candidates",
+                    name = "min_candidates",
                     type = "integer",
                     valueHint = { "10" },
-                    description = "The number of nearest neighbor candidates to consider per shard while doing knn search. "
-                        + "Cannot exceed 10,000. Increasing num_candidates tends to improve the accuracy of the final results. "
-                        + "Defaults to 1.5 * k"
+                    description = "The minimum number of nearest neighbor candidates to consider per shard while doing knn search. "
+                        + " KNN may use a higher number of candidates in case the query can't use a approximate results. "
+                        + "Cannot exceed 10,000. Increasing min_candidates tends to improve the accuracy of the final results. "
+                        + "Defaults to 1.5 * LIMIT used for the query."
                 ),
                 @MapParam.MapParamEntry(
                     name = "similarity",
@@ -150,32 +149,29 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
             optional = true
         ) Expression options
     ) {
-        this(source, field, query, k, options, null, List.of());
+        this(source, field, query, options, null, null, List.of());
     }
 
     public Knn(
         Source source,
         Expression field,
         Expression query,
-        Expression k,
         Expression options,
+        Integer k,
         QueryBuilder queryBuilder,
         List<Expression> filterExpressions
     ) {
-        super(source, query, expressionList(field, query, k, options), queryBuilder);
+        super(source, query, expressionList(field, query, options), queryBuilder);
         this.field = field;
         this.k = k;
         this.options = options;
         this.filterExpressions = filterExpressions;
     }
 
-    private static List<Expression> expressionList(Expression field, Expression query, Expression k, Expression options) {
+    private static List<Expression> expressionList(Expression field, Expression query, Expression options) {
         List<Expression> result = new ArrayList<>();
         result.add(field);
         result.add(query);
-        if (k != null) {
-            result.add(k);
-        }
         if (options != null) {
             result.add(options);
         }
@@ -186,7 +182,7 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
         return field;
     }
 
-    public Expression k() {
+    public Integer k() {
         return k;
     }
 
@@ -205,7 +201,7 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
 
     @Override
     protected TypeResolution resolveParams() {
-        return resolveField().and(resolveQuery()).and(resolveK()).and(Options.resolve(options(), source(), FOURTH, ALLOWED_OPTIONS));
+        return resolveField().and(resolveQuery()).and(Options.resolve(options(), source(), THIRD, ALLOWED_OPTIONS));
     }
 
     private TypeResolution resolveField() {
@@ -225,14 +221,9 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
         return TypeResolution.TYPE_RESOLVED;
     }
 
-    private TypeResolution resolveK() {
-        if (k == null) {
-            // Function has already been rewritten and included in QueryBuilder - otherwise parsing would have failed
-            return TypeResolution.TYPE_RESOLVED;
-        }
-
-        return isType(k(), dt -> dt == INTEGER, sourceText(), THIRD, "integer").and(isFoldable(k(), sourceText(), THIRD))
-            .and(isNotNull(k(), sourceText(), THIRD));
+    public Knn replaceK(Integer k) {
+        Check.notNull(k, "k must not be null");
+        return new Knn(source(), field(), query(), options(), k, queryBuilder(), filterExpressions());
     }
 
     public List<Number> queryAsObject() {
@@ -246,16 +237,9 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
         throw new EsqlIllegalArgumentException(format(null, "Query value must be a list of numbers in [{}], found [{}]", source(), query));
     }
 
-    int getKIntValue() {
-        if (k() instanceof Literal literal) {
-            return (int) (Number) literal.value();
-        }
-        throw new EsqlIllegalArgumentException(format(null, "K value must be a constant integer in [{}], found [{}]", source(), k()));
-    }
-
     @Override
     public Expression replaceQueryBuilder(QueryBuilder queryBuilder) {
-        return new Knn(source(), field(), query(), k(), options(), queryBuilder, filterExpressions());
+        return new Knn(source(), field(), query(), options(), k(), queryBuilder, filterExpressions());
     }
 
     @Override
@@ -271,37 +255,39 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
 
     @Override
     protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
+        assert k() != null : "Knn function must have a k value set before translation";
         var fieldAttribute = Match.fieldAsFieldAttribute(field());
 
         Check.notNull(fieldAttribute, "Knn must have a field attribute as the first argument");
         String fieldName = getNameFromFieldAttribute(fieldAttribute);
-        List<Number> queryFolded = queryAsObject();
-        float[] queryAsFloats = new float[queryFolded.size()];
-        for (int i = 0; i < queryFolded.size(); i++) {
-            queryAsFloats[i] = queryFolded.get(i).floatValue();
-        }
-        int kValue = getKIntValue();
-
-        Map<String, Object> opts = queryOptions();
-        opts.put(K_FIELD.getPreferredName(), kValue);
+        float[] queryAsFloats = queryAsFloats();
 
         List<QueryBuilder> filterQueries = new ArrayList<>();
         for (Expression filterExpression : filterExpressions()) {
             if (filterExpression instanceof TranslationAware translationAware) {
                 // We can only translate filter expressions that are translatable. In case any is not translatable,
-                // Knn won't be pushed down as it will not be translatable so it's safe not to translate all filters and check them
-                // when creating an evaluator for the non-pushed down query
+                // Knn won't be pushed down so it's safe not to translate all filters and check them when creating an evaluator
+                // for the non-pushed down query
                 if (translationAware.translatable(pushdownPredicates) == Translatable.YES) {
                     filterQueries.add(handler.asQuery(pushdownPredicates, filterExpression).toQueryBuilder());
                 }
             }
         }
 
-        return new KnnQuery(source(), fieldName, queryAsFloats, opts, filterQueries);
+        return new KnnQuery(source(), fieldName, queryAsFloats, k(), queryOptions(), filterQueries);
+    }
+
+    private float[] queryAsFloats() {
+        List<Number> queryFolded = queryAsObject();
+        float[] queryAsFloats = new float[queryFolded.size()];
+        for (int i = 0; i < queryFolded.size(); i++) {
+            queryAsFloats[i] = queryFolded.get(i).floatValue();
+        }
+        return queryAsFloats;
     }
 
     public Expression withFilters(List<Expression> filterExpressions) {
-        return new Knn(source(), field(), query(), k(), options(), queryBuilder(), filterExpressions);
+        return new Knn(source(), field(), query(), options(), k(), queryBuilder(), filterExpressions);
     }
 
     private Map<String, Object> queryOptions() throws InvalidArgumentException {
@@ -312,6 +298,17 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
         return options;
     }
 
+    protected QueryBuilder evaluatorQueryBuilder() {
+        // Either we couldn't push down due to non-pushable filters, or because it's part of a disjuncion.
+        // Uses a nearest neighbors exact query instead of an approximate one
+        var fieldAttribute = Match.fieldAsFieldAttribute(field());
+        Check.notNull(fieldAttribute, "Knn must have a field attribute as the first argument");
+        String fieldName = getNameFromFieldAttribute(fieldAttribute);
+        Map<String, Object> opts = queryOptions();
+
+        return new ExactKnnQueryBuilder(VectorData.fromFloats(queryAsFloats()), fieldName, (Float) opts.get(VECTOR_SIMILARITY_FIELD));
+    }
+
     @Override
     public BiConsumer<LogicalPlan, Failures> postAnalysisPlanVerification() {
         return (plan, failures) -> {
@@ -320,14 +317,24 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
         };
     }
 
+    @Override
+    public void postOptimizationVerification(Failures failures) {
+        // Check that a k has been set
+        if (k() == null) {
+            failures.add(
+                Failure.fail(this, "Knn function must be used with a LIMIT clause after it to set the number of nearest neighbors to find")
+            );
+        }
+    }
+
     @Override
     public Expression replaceChildren(List<Expression> newChildren) {
         return new Knn(
             source(),
             newChildren.get(0),
             newChildren.get(1),
-            newChildren.get(2),
-            newChildren.size() > 3 ? newChildren.get(3) : null,
+            newChildren.size() > 2 ? newChildren.get(2) : null,
+            k(),
             queryBuilder(),
             filterExpressions()
         );
@@ -335,7 +342,7 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
 
     @Override
     protected NodeInfo<? extends Expression> info() {
-        return NodeInfo.create(this, Knn::new, field(), query(), k(), options(), queryBuilder(), filterExpressions());
+        return NodeInfo.create(this, Knn::new, field(), query(), options(), k(), queryBuilder(), filterExpressions());
     }
 
     @Override

+ 1 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java

@@ -27,7 +27,7 @@ public final class VectorWritables {
     public static List<NamedWriteableRegistry.Entry> getNamedWritables() {
         List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
 
-        if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) {
+        if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) {
             entries.add(Knn.ENTRY);
         }
         if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {

+ 2 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java

@@ -44,6 +44,7 @@ import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownEval;
 import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownInferencePlan;
 import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownJoinPastProject;
 import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownRegexExtract;
+import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushLimitToKnn;
 import org.elasticsearch.xpack.esql.optimizer.rules.logical.RemoveStatsOverride;
 import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceAggregateAggExpressionWithEval;
 import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceAggregateNestedExpressionWithEval;
@@ -192,6 +193,7 @@ public class LogicalPlanOptimizer extends ParameterizedRuleExecutor<LogicalPlan,
             new PruneColumns(),
             new PruneLiteralsInOrderBy(),
             new PushDownAndCombineLimits(),
+            new PushLimitToKnn(),
             new PushDownAndCombineFilters(),
             new PushDownConjunctionsToKnnPrefilters(),
             new PushDownAndCombineSample(),

+ 69 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushLimitToKnn.java

@@ -0,0 +1,69 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.optimizer.rules.logical;
+
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.util.Holder;
+import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
+import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
+import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
+import org.elasticsearch.xpack.esql.plan.logical.Filter;
+import org.elasticsearch.xpack.esql.plan.logical.Limit;
+import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
+import org.elasticsearch.xpack.esql.plan.logical.TopN;
+import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
+
+/**
+ * Traverses the logical plan and pushes down the limit to the KNN function(s) in filter expressions, so KNN can use
+ * it to set k if not specified.
+ */
+public class PushLimitToKnn extends OptimizerRules.ParameterizedOptimizerRule<Limit, LogicalOptimizerContext> {
+
+    public PushLimitToKnn() {
+        super(OptimizerRules.TransformDirection.DOWN);
+    }
+
+    @Override
+    public LogicalPlan rule(Limit limit, LogicalOptimizerContext ctx) {
+        Holder<Boolean> breakerReached = new Holder<>(false);
+        Holder<Boolean> firstLimit = new Holder<>(false);
+        return limit.transformDown(plan -> {
+            if (breakerReached.get()) {
+                // We reached a breaker and don't want to continue processing
+                return plan;
+            }
+            if (plan instanceof Filter filter) {
+                Expression limitAppliedExpression = limitFilterExpressions(filter.condition(), limit, ctx);
+                if (limitAppliedExpression.equals(filter.condition()) == false) {
+                    return filter.with(limitAppliedExpression);
+                }
+            } else if (plan instanceof Limit) {
+                // Break if it's not the initial limit
+                breakerReached.set(firstLimit.get());
+                firstLimit.set(true);
+            } else if (plan instanceof TopN || plan instanceof Rerank || plan instanceof Aggregate) {
+                breakerReached.set(true);
+            }
+
+            return plan;
+        });
+    }
+
+    /**
+     * Applies a limit to the filter expressions of a condition. Some filter expressions, such as KNN function,
+     * can be optimized by applying the limit directly to them.
+     */
+    private Expression limitFilterExpressions(Expression condition, Limit limit, LogicalOptimizerContext ctx) {
+        return condition.transformDown(exp -> {
+            if (exp instanceof Knn knn) {
+                return knn.replaceK((Integer) limit.limit().fold(ctx.foldCtx()));
+            }
+            return exp;
+        });
+    }
+}

+ 17 - 7
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java

@@ -12,6 +12,7 @@ import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
 import org.elasticsearch.search.vectors.RescoreVectorBuilder;
 import org.elasticsearch.xpack.esql.core.querydsl.query.Query;
 import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
 
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -20,8 +21,6 @@ import java.util.Map;
 import java.util.Objects;
 
 import static org.elasticsearch.index.query.AbstractQueryBuilder.BOOST_FIELD;
-import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.K_FIELD;
-import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD;
 import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VECTOR_SIMILARITY_FIELD;
 
 public class KnnQuery extends Query {
@@ -32,9 +31,12 @@ public class KnnQuery extends Query {
     private final List<QueryBuilder> filterQueries;
 
     public static final String RESCORE_OVERSAMPLE_FIELD = "rescore_oversample";
+    private final Integer k;
 
-    public KnnQuery(Source source, String field, float[] query, Map<String, Object> options, List<QueryBuilder> filterQueries) {
+    public KnnQuery(Source source, String field, float[] query, Integer k, Map<String, Object> options, List<QueryBuilder> filterQueries) {
         super(source);
+        assert k != null && k > 0 : "k must be a positive integer, but was: " + k;
+        this.k = k;
         assert options != null;
         this.field = field;
         this.query = query;
@@ -44,16 +46,24 @@ public class KnnQuery extends Query {
 
     @Override
     protected QueryBuilder asBuilder() {
-        Integer k = (Integer) options.get(K_FIELD.getPreferredName());
-        Integer numCands = (Integer) options.get(NUM_CANDS_FIELD.getPreferredName());
         RescoreVectorBuilder rescoreVectorBuilder = null;
         Float oversample = (Float) options.get(RESCORE_OVERSAMPLE_FIELD);
         if (oversample != null) {
             rescoreVectorBuilder = new RescoreVectorBuilder(oversample);
         }
         Float vectorSimilarity = (Float) options.get(VECTOR_SIMILARITY_FIELD.getPreferredName());
-
-        KnnVectorQueryBuilder queryBuilder = new KnnVectorQueryBuilder(field, query, k, numCands, rescoreVectorBuilder, vectorSimilarity);
+        Integer minCandidates = (Integer) options.get(Knn.MIN_CANDIDATES_OPTION);
+        int adjustedK = Math.max(k, minCandidates == null ? 0 : minCandidates);
+        minCandidates = minCandidates == null ? null : Math.max(minCandidates, adjustedK);
+
+        KnnVectorQueryBuilder queryBuilder = new KnnVectorQueryBuilder(
+            field,
+            query,
+            adjustedK,
+            minCandidates,
+            rescoreVectorBuilder,
+            vectorSimilarity
+        );
         for (QueryBuilder filter : filterQueries) {
             queryBuilder.addFilterQuery(filter);
         }

+ 1 - 1
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java

@@ -305,7 +305,7 @@ public class CsvTests extends ESTestCase {
             );
             assumeFalse(
                 "can't use KNN function in csv tests",
-                testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KNN_FUNCTION_V3.capabilityName())
+                testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KNN_FUNCTION_V4.capabilityName())
             );
             assumeFalse(
                 "lookup join disabled for csv tests",

+ 2 - 3
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java

@@ -2349,20 +2349,19 @@ public class AnalyzerTests extends ESTestCase {
 
     public void testDenseVectorImplicitCastingKnn() {
         assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled());
-        assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+        assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());
 
         checkDenseVectorCastingKnn("float_vector");
     }
 
     private static void checkDenseVectorCastingKnn(String fieldName) {
         var plan = analyze(String.format(Locale.ROOT, """
-            from test | where knn(%s, [0.342, 0.164, 0.234], 10)
+            from test | where knn(%s, [0.342, 0.164, 0.234])
             """, fieldName), "mapping-dense_vector.json");
 
         var limit = as(plan, Limit.class);
         var filter = as(limit.child(), Filter.class);
         var knn = as(filter.condition(), Knn.class);
-        var field = knn.field();
         var queryVector = as(knn.query(), Literal.class);
         assertEquals(DataType.DENSE_VECTOR, queryVector.dataType());
         assertThat(queryVector.value(), equalTo(List.of(0.342f, 0.164f, 0.234f)));

+ 16 - 17
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java

@@ -1268,8 +1268,8 @@ public class VerifierTests extends ESTestCase {
             checkFieldBasedWithNonIndexedColumn("Term", "term(text, \"cat\")", "function");
             checkFieldBasedFunctionNotAllowedAfterCommands("Term", "function", "term(title, \"Meditation\")");
         }
-        if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) {
-            checkFieldBasedFunctionNotAllowedAfterCommands("KNN", "function", "knn(vector, [1, 2, 3], 10)");
+        if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) {
+            checkFieldBasedFunctionNotAllowedAfterCommands("KNN", "function", "knn(vector, [1, 2, 3])");
         }
     }
 
@@ -1401,8 +1401,8 @@ public class VerifierTests extends ESTestCase {
         if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) {
             checkFullTextFunctionsOnlyAllowedInWhere("MultiMatch", "multi_match(\"Meditation\", title, body)", "function");
         }
-        if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) {
-            checkFullTextFunctionsOnlyAllowedInWhere("KNN", "knn(vector, [0, 1, 2], 10)", "function");
+        if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) {
+            checkFullTextFunctionsOnlyAllowedInWhere("KNN", "knn(vector, [0, 1, 2])", "function");
         }
 
     }
@@ -1456,8 +1456,8 @@ public class VerifierTests extends ESTestCase {
         if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) {
             checkWithFullTextFunctionsDisjunctions("term(title, \"Meditation\")");
         }
-        if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) {
-            checkWithFullTextFunctionsDisjunctions("knn(vector, [1, 2, 3], 10)");
+        if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) {
+            checkWithFullTextFunctionsDisjunctions("knn(vector, [1, 2, 3])");
         }
     }
 
@@ -1521,8 +1521,8 @@ public class VerifierTests extends ESTestCase {
         if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) {
             checkFullTextFunctionsWithNonBooleanFunctions("Term", "term(title, \"Meditation\")", "function");
         }
-        if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) {
-            checkFullTextFunctionsWithNonBooleanFunctions("KNN", "knn(vector, [1, 2, 3], 10)", "function");
+        if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) {
+            checkFullTextFunctionsWithNonBooleanFunctions("KNN", "knn(vector, [1, 2, 3])", "function");
         }
     }
 
@@ -1592,7 +1592,7 @@ public class VerifierTests extends ESTestCase {
         if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) {
             testFullTextFunctionTargetsExistingField("term(fist_name, \"Meditation\")");
         }
-        if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) {
+        if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) {
             testFullTextFunctionTargetsExistingField("knn(vector, [0, 1, 2], 10)");
         }
     }
@@ -2189,8 +2189,8 @@ public class VerifierTests extends ESTestCase {
         if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) {
             checkOptionDataTypes(MultiMatch.OPTIONS, "FROM test | WHERE MULTI_MATCH(\"Jean\", title, body, {\"%s\": %s})");
         }
-        if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) {
-            checkOptionDataTypes(Knn.ALLOWED_OPTIONS, "FROM test | WHERE KNN(vector, [0.1, 0.2, 0.3], 10, {\"%s\": %s})");
+        if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) {
+            checkOptionDataTypes(Knn.ALLOWED_OPTIONS, "FROM test | WHERE KNN(vector, [0.1, 0.2, 0.3], {\"%s\": %s})");
         }
     }
 
@@ -2282,10 +2282,9 @@ public class VerifierTests extends ESTestCase {
             checkFullTextFunctionNullArgs("term(null, \"query\")", "first");
             checkFullTextFunctionNullArgs("term(title, null)", "second");
         }
-        if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) {
-            checkFullTextFunctionNullArgs("knn(null, [0, 1, 2], 10)", "first");
-            checkFullTextFunctionNullArgs("knn(vector, null, 10)", "second");
-            checkFullTextFunctionNullArgs("knn(vector, [0, 1, 2], null)", "third");
+        if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) {
+            checkFullTextFunctionNullArgs("knn(null, [0, 1, 2])", "first");
+            checkFullTextFunctionNullArgs("knn(vector, null)", "second");
         }
     }
 
@@ -2314,8 +2313,8 @@ public class VerifierTests extends ESTestCase {
         if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) {
             checkFullTextFunctionsInStats("multi_match(\"Meditation\", title, body)");
         }
-        if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) {
-            checkFullTextFunctionsInStats("knn(vector, [0, 1, 2], 10)");
+        if (EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled()) {
+            checkFullTextFunctionsInStats("knn(vector, [0, 1, 2])");
         }
     }
 

+ 2 - 2
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java

@@ -52,7 +52,7 @@ public class KnnTests extends AbstractFunctionTestCase {
 
     @Before
     public void checkCapability() {
-        assumeTrue("KNN is not enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+        assumeTrue("KNN is not enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());
     }
 
     private static List<TestCaseSupplier> testCaseSuppliers() {
@@ -121,7 +121,7 @@ public class KnnTests extends AbstractFunctionTestCase {
 
     @Override
     protected Expression build(Source source, List<Expression> args) {
-        Knn knn = new Knn(source, args.get(0), args.get(1), args.get(2), args.size() > 3 ? args.get(3) : null);
+        Knn knn = new Knn(source, args.get(0), args.get(1), args.size() > 2 ? args.get(2) : null);
         // We need to add the QueryBuilder to the match expression, as it is used to implement equals() and hashCode() and
         // thus test the serialization methods. But we can only do this if the parameters make sense .
         if (args.get(0) instanceof FieldAttribute && args.get(1).foldable()) {

+ 2 - 1
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/AbstractLogicalPlanOptimizerTests.java

@@ -32,6 +32,7 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolutio
 import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
 import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;
 import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning;
+import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.defaultInferenceResolution;
 import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.defaultLookupResolution;
 import static org.elasticsearch.xpack.esql.core.type.DataType.KEYWORD;
 
@@ -118,7 +119,7 @@ public abstract class AbstractLogicalPlanOptimizerTests extends ESTestCase {
                 new EsqlFunctionRegistry(),
                 getIndexResultTypes,
                 enrichResolution,
-                emptyInferenceResolution()
+                defaultInferenceResolution()
             ),
             TEST_VERIFIER
         );

+ 98 - 41
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java

@@ -1376,12 +1376,12 @@ public class LocalPhysicalPlanOptimizerTests extends MapperServiceTestCase {
 
     public void testKnnOptionsPushDown() {
         assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled());
-        assumeTrue("knn capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+        assumeTrue("knn capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());
 
         String query = """
             from test
-            | where KNN(dense_vector, [0.1, 0.2, 0.3], 5,
-                { "similarity": 0.001, "num_candidates": 10, "rescore_oversample": 7, "boost": 3.5 })
+            | where KNN(dense_vector, [0.1, 0.2, 0.3],
+                { "similarity": 0.001, "min_candidates": 5000, "rescore_oversample": 7, "boost": 3.5 })
             """;
         var analyzer = makeAnalyzer("mapping-all-types.json");
         var plan = plannerOptimizer.plan(query, IS_SV_STATS, analyzer);
@@ -1392,12 +1392,69 @@ public class LocalPhysicalPlanOptimizerTests extends MapperServiceTestCase {
         var expectedQuery = new KnnVectorQueryBuilder(
             "dense_vector",
             new float[] { 0.1f, 0.2f, 0.3f },
-            5,
-            10,
+            5000,
+            5000,
             new RescoreVectorBuilder(7),
             0.001f
         ).boost(3.5f);
-        assertThat(expectedQuery.toString(), is(planStr.get()));
+        assertEquals(expectedQuery.toString(), planStr.get());
+    }
+
+    public void testKnnUsesLimitForK() {
+        assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled());
+        assumeTrue("knn capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());
+
+        String query = """
+            from test
+            | where KNN(dense_vector, [0.1, 0.2, 0.3])
+            | limit 10
+            """;
+        var analyzer = makeAnalyzer("mapping-all-types.json");
+        var plan = plannerOptimizer.plan(query, IS_SV_STATS, analyzer);
+
+        AtomicReference<String> planStr = new AtomicReference<>();
+        plan.forEachDown(EsQueryExec.class, result -> planStr.set(result.query().toString()));
+
+        var expectedQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 0.1f, 0.2f, 0.3f }, 10, null, null, null);
+        assertEquals(expectedQuery.toString(), planStr.get());
+    }
+
+    public void testKnnKAndMinCandidatesLowerK() {
+        assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled());
+        assumeTrue("knn capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());
+
+        String query = """
+            from test
+            | where KNN(dense_vector, [0.1, 0.2, 0.3], {"min_candidates": 50})
+            | limit 10
+            """;
+        var analyzer = makeAnalyzer("mapping-all-types.json");
+        var plan = plannerOptimizer.plan(query, IS_SV_STATS, analyzer);
+
+        AtomicReference<String> planStr = new AtomicReference<>();
+        plan.forEachDown(EsQueryExec.class, result -> planStr.set(result.query().toString()));
+
+        var expectedQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 0.1f, 0.2f, 0.3f }, 50, 50, null, null);
+        assertEquals(expectedQuery.toString(), planStr.get());
+    }
+
+    public void testKnnKAndMinCandidatesHigherK() {
+        assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled());
+        assumeTrue("knn capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());
+
+        String query = """
+            from test
+            | where KNN(dense_vector, [0.1, 0.2, 0.3], {"min_candidates": 10})
+            | limit 50
+            """;
+        var analyzer = makeAnalyzer("mapping-all-types.json");
+        var plan = plannerOptimizer.plan(query, IS_SV_STATS, analyzer);
+
+        AtomicReference<String> planStr = new AtomicReference<>();
+        plan.forEachDown(EsQueryExec.class, result -> planStr.set(result.query().toString()));
+
+        var expectedQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 0.1f, 0.2f, 0.3f }, 50, 50, null, null);
+        assertEquals(expectedQuery.toString(), planStr.get());
     }
 
     /**
@@ -1842,11 +1899,11 @@ public class LocalPhysicalPlanOptimizerTests extends MapperServiceTestCase {
     }
 
     public void testKnnPrefilters() {
-        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());
 
         String query = """
             from test
-            | where knn(dense_vector, [0, 1, 2], 10) and integer > 10
+            | where knn(dense_vector, [0, 1, 2]) and integer > 10
             """;
         var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json"));
 
@@ -1859,12 +1916,12 @@ public class LocalPhysicalPlanOptimizerTests extends MapperServiceTestCase {
             query,
             unscore(rangeQuery("integer").gt(10)),
             "integer",
-            new Source(2, 45, "integer > 10")
+            new Source(2, 41, "integer > 10")
         );
         KnnVectorQueryBuilder expectedKnnQueryBuilder = new KnnVectorQueryBuilder(
             "dense_vector",
             new float[] { 0, 1, 2 },
-            10,
+            1000,
             null,
             null,
             null
@@ -1874,11 +1931,11 @@ public class LocalPhysicalPlanOptimizerTests extends MapperServiceTestCase {
     }
 
     public void testKnnPrefiltersWithMultipleFilters() {
-        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());
 
         String query = """
             from test
-            | where knn(dense_vector, [0, 1, 2], 10)
+            | where knn(dense_vector, [0, 1, 2])
             | where integer > 10
             | where keyword == "test"
             """;
@@ -1900,7 +1957,7 @@ public class LocalPhysicalPlanOptimizerTests extends MapperServiceTestCase {
         KnnVectorQueryBuilder expectedKnnQueryBuilder = new KnnVectorQueryBuilder(
             "dense_vector",
             new float[] { 0, 1, 2 },
-            10,
+            1000,
             null,
             null,
             null
@@ -1910,11 +1967,11 @@ public class LocalPhysicalPlanOptimizerTests extends MapperServiceTestCase {
     }
 
     public void testPushDownConjunctionsToKnnPrefilter() {
-        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());
 
         String query = """
             from test
-            | where knn(dense_vector, [0, 1, 2], 10) and integer > 10
+            | where knn(dense_vector, [0, 1, 2]) and integer > 10
             """;
         var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json"));
 
@@ -1929,13 +1986,13 @@ public class LocalPhysicalPlanOptimizerTests extends MapperServiceTestCase {
             query,
             unscore(rangeQuery("integer").gt(10)),
             "integer",
-            new Source(2, 45, "integer > 10")
+            new Source(2, 41, "integer > 10")
         );
 
         KnnVectorQueryBuilder expectedKnnQueryBuilder = new KnnVectorQueryBuilder(
             "dense_vector",
             new float[] { 0, 1, 2 },
-            10,
+            1000,
             null,
             null,
             null
@@ -1947,11 +2004,11 @@ public class LocalPhysicalPlanOptimizerTests extends MapperServiceTestCase {
     }
 
     public void testPushDownNegatedConjunctionsToKnnPrefilter() {
-        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());
 
         String query = """
             from test
-            | where knn(dense_vector, [0, 1, 2], 10) and NOT integer > 10
+            | where knn(dense_vector, [0, 1, 2]) and NOT integer > 10
             """;
         var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json"));
 
@@ -1966,13 +2023,13 @@ public class LocalPhysicalPlanOptimizerTests extends MapperServiceTestCase {
             query,
             unscore(boolQuery().mustNot(unscore(rangeQuery("integer").gt(10)))),
             "integer",
-            new Source(2, 45, "NOT integer > 10")
+            new Source(2, 41, "NOT integer > 10")
         );
 
         KnnVectorQueryBuilder expectedKnnQueryBuilder = new KnnVectorQueryBuilder(
             "dense_vector",
             new float[] { 0, 1, 2 },
-            10,
+            1000,
             null,
             null,
             null
@@ -1984,11 +2041,11 @@ public class LocalPhysicalPlanOptimizerTests extends MapperServiceTestCase {
     }
 
     public void testNotPushDownDisjunctionsToKnnPrefilter() {
-        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());
 
         String query = """
             from test
-            | where knn(dense_vector, [0, 1, 2], 10) or integer > 10
+            | where knn(dense_vector, [0, 1, 2]) or integer > 10
             """;
         var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json"));
 
@@ -1999,12 +2056,12 @@ public class LocalPhysicalPlanOptimizerTests extends MapperServiceTestCase {
         var queryExec = as(field.child(), EsQueryExec.class);
 
         // The disjunction should not be pushed down to the KNN query
-        KnnVectorQueryBuilder knnQueryBuilder = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 10, null, null, null);
+        KnnVectorQueryBuilder knnQueryBuilder = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 1000, null, null, null);
         QueryBuilder rangeQueryBuilder = wrapWithSingleQuery(
             query,
             unscore(rangeQuery("integer").gt(10)),
             "integer",
-            new Source(2, 44, "integer > 10")
+            new Source(2, 40, "integer > 10")
         );
 
         var expectedQuery = boolQuery().should(knnQueryBuilder).should(rangeQueryBuilder);
@@ -2013,11 +2070,11 @@ public class LocalPhysicalPlanOptimizerTests extends MapperServiceTestCase {
     }
 
     public void testNotPushDownKnnWithNonPushablePrefilters() {
-        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());
 
         String query = """
             from test
-            | where ((knn(dense_vector, [0, 1, 2], 10) AND integer > 10) and ((keyword == "test") or length(text) > 10))
+            | where ((knn(dense_vector, [0, 1, 2]) AND integer > 10) and ((keyword == "test") or length(text) > 10))
             """;
         var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json"));
 
@@ -2040,19 +2097,19 @@ public class LocalPhysicalPlanOptimizerTests extends MapperServiceTestCase {
             query,
             unscore(rangeQuery("integer").gt(10)),
             "integer",
-            new Source(2, 47, "integer > 10")
+            new Source(2, 43, "integer > 10")
         );
 
         assertEquals(integerGtQuery.toString(), queryExec.query().toString());
     }
 
     public void testPushDownComplexNegationsToKnnPrefilter() {
-        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());
 
         String query = """
             from test
-            | where ((knn(dense_vector, [0, 1, 2], 10) or NOT integer > 10)
-              and NOT ((keyword == "test") or knn(dense_vector, [4, 5, 6], 10)))
+            | where ((knn(dense_vector, [0, 1, 2]) or NOT integer > 10)
+              and NOT ((keyword == "test") or knn(dense_vector, [4, 5, 6])))
             """;
         var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json"));
 
@@ -2072,18 +2129,18 @@ public class LocalPhysicalPlanOptimizerTests extends MapperServiceTestCase {
             query,
             unscore(boolQuery().mustNot(unscore(termQuery("keyword", "test")))),
             "keyword",
-            new Source(3, 6, "NOT ((keyword == \"test\") or knn(dense_vector, [4, 5, 6], 10))")
+            new Source(3, 6, "NOT ((keyword == \"test\") or knn(dense_vector, [4, 5, 6]))")
         );
 
         QueryBuilder notIntegerGt10 = wrapWithSingleQuery(
             query,
             unscore(boolQuery().mustNot(unscore(rangeQuery("integer").gt(10)))),
             "integer",
-            new Source(2, 46, "NOT integer > 10")
+            new Source(2, 42, "NOT integer > 10")
         );
 
-        KnnVectorQueryBuilder firstKnn = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 10, null, null, null);
-        KnnVectorQueryBuilder secondKnn = new KnnVectorQueryBuilder("dense_vector", new float[] { 4, 5, 6 }, 10, null, null, null);
+        KnnVectorQueryBuilder firstKnn = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 1000, null, null, null);
+        KnnVectorQueryBuilder secondKnn = new KnnVectorQueryBuilder("dense_vector", new float[] { 4, 5, 6 }, 1000, null, null, null);
 
         firstKnn.addFilterQuery(notKeywordFilter);
         secondKnn.addFilterQuery(notIntegerGt10);
@@ -2097,11 +2154,11 @@ public class LocalPhysicalPlanOptimizerTests extends MapperServiceTestCase {
     }
 
     public void testMultipleKnnQueriesInPrefilters() {
-        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());
 
         String query = """
             from test
-            | where ((knn(dense_vector, [0, 1, 2], 10) or integer > 10) and ((keyword == "test") or knn(dense_vector, [4, 5, 6], 10)))
+            | where ((knn(dense_vector, [0, 1, 2]) or integer > 10) and ((keyword == "test") or knn(dense_vector, [4, 5, 6])))
             """;
         var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json"));
 
@@ -2111,24 +2168,24 @@ public class LocalPhysicalPlanOptimizerTests extends MapperServiceTestCase {
         var field = as(project.child(), FieldExtractExec.class);
         var queryExec = as(field.child(), EsQueryExec.class);
 
-        KnnVectorQueryBuilder firstKnnQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 10, null, null, null);
+        KnnVectorQueryBuilder firstKnnQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 1000, null, null, null);
         // Integer range query (right side of first OR)
         QueryBuilder integerRangeQuery = wrapWithSingleQuery(
             query,
             unscore(rangeQuery("integer").gt(10)),
             "integer",
-            new Source(2, 46, "integer > 10")
+            new Source(2, 42, "integer > 10")
         );
 
         // Second KNN query (right side of second OR)
-        KnnVectorQueryBuilder secondKnnQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 4, 5, 6 }, 10, null, null, null);
+        KnnVectorQueryBuilder secondKnnQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 4, 5, 6 }, 1000, null, null, null);
 
         // Keyword term query (left side of second OR)
         QueryBuilder keywordQuery = wrapWithSingleQuery(
             query,
             unscore(termQuery("keyword", "test")),
             "keyword",
-            new Source(2, 66, "keyword == \"test\"")
+            new Source(2, 62, "keyword == \"test\"")
         );
 
         // First OR (knn1 OR integer > 10)

+ 162 - 12
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java

@@ -8499,11 +8499,11 @@ public class LogicalPlanOptimizerTests extends AbstractLogicalPlanOptimizerTests
     }
 
     public void testPushDownConjunctionsToKnnPrefilter() {
-        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());
 
         var query = """
             from test
-            | where knn(dense_vector, [0, 1, 2], 10) and integer > 10
+            | where knn(dense_vector, [0, 1, 2]) and integer > 10
             """;
         var optimized = planTypes(query);
 
@@ -8519,11 +8519,11 @@ public class LogicalPlanOptimizerTests extends AbstractLogicalPlanOptimizerTests
     }
 
     public void testPushDownMultipleFiltersToKnnPrefilter() {
-        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());
 
         var query = """
             from test
-            | where knn(dense_vector, [0, 1, 2], 10)
+            | where knn(dense_vector, [0, 1, 2])
             | where integer > 10
             | where keyword == "test"
             """;
@@ -8542,11 +8542,11 @@ public class LogicalPlanOptimizerTests extends AbstractLogicalPlanOptimizerTests
     }
 
     public void testNotPushDownDisjunctionsToKnnPrefilter() {
-        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());
 
         var query = """
             from test
-            | where knn(dense_vector, [0, 1, 2], 10) or integer > 10
+            | where knn(dense_vector, [0, 1, 2]) or integer > 10
             """;
         var optimized = planTypes(query);
 
@@ -8559,7 +8559,7 @@ public class LogicalPlanOptimizerTests extends AbstractLogicalPlanOptimizerTests
     }
 
     public void testPushDownConjunctionsAndNotDisjunctionsToKnnPrefilter() {
-        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());
 
         /*
             and
@@ -8576,7 +8576,7 @@ public class LogicalPlanOptimizerTests extends AbstractLogicalPlanOptimizerTests
         var query = """
             from test
             | where
-                 ((knn(dense_vector, [0, 1, 2], 10) or integer > 10) and keyword == "test") and ((short < 5) or (double > 5.0))
+                 ((knn(dense_vector, [0, 1, 2]) or integer > 10) and keyword == "test") and ((short < 5) or (double > 5.0))
             """;
         var optimized = planTypes(query);
 
@@ -8594,7 +8594,7 @@ public class LogicalPlanOptimizerTests extends AbstractLogicalPlanOptimizerTests
     }
 
     public void testMorePushDownConjunctionsAndNotDisjunctionsToKnnPrefilter() {
-        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());
 
         /*
             or
@@ -8611,7 +8611,7 @@ public class LogicalPlanOptimizerTests extends AbstractLogicalPlanOptimizerTests
         var query = """
             from test
             | where
-                 ((knn(dense_vector, [0, 1, 2], 10) and integer > 10) or keyword == "test") or ((short < 5) and (double > 5.0))
+                 ((knn(dense_vector, [0, 1, 2]) and integer > 10) or keyword == "test") or ((short < 5) and (double > 5.0))
             """;
         var optimized = planTypes(query);
 
@@ -8626,7 +8626,7 @@ public class LogicalPlanOptimizerTests extends AbstractLogicalPlanOptimizerTests
     }
 
     public void testMultipleKnnQueriesInPrefilters() {
-        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());
 
         /*
             and
@@ -8639,7 +8639,7 @@ public class LogicalPlanOptimizerTests extends AbstractLogicalPlanOptimizerTests
          */
         var query = """
             from test
-            | where ((knn(dense_vector, [0, 1, 2], 10) or integer > 10) and ((keyword == "test") or knn(dense_vector, [4, 5, 6], 10)))
+            | where ((knn(dense_vector, [0, 1, 2]) or integer > 10) and ((keyword == "test") or knn(dense_vector, [4, 5, 6])))
             """;
         var optimized = planTypes(query);
 
@@ -8668,6 +8668,156 @@ public class LogicalPlanOptimizerTests extends AbstractLogicalPlanOptimizerTests
         assertTrue(secondKnnFilters.contains(firstOr.right()));
     }
 
+    public void testKnnImplicitLimit() {
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());
+
+        var query = """
+            from test
+            | where knn(dense_vector, [0, 1, 2])
+            """;
+        var optimized = planTypes(query);
+
+        var limit = as(optimized, Limit.class);
+        var filter = as(limit.child(), Filter.class);
+        var knn = as(filter.condition(), Knn.class);
+        assertThat(knn.k(), equalTo(1000));
+    }
+
+    public void testKnnWithLimit() {
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());
+
+        var query = """
+            from test
+            | where knn(dense_vector, [0, 1, 2])
+            | limit 10
+            """;
+        var optimized = planTypes(query);
+
+        var limit = as(optimized, Limit.class);
+        var filter = as(limit.child(), Filter.class);
+        var knn = as(filter.condition(), Knn.class);
+        assertThat(knn.k(), equalTo(10));
+    }
+
+    public void testKnnWithTopN() {
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());
+
+        var query = """
+            from test metadata _score
+            | where knn(dense_vector, [0, 1, 2])
+            | sort _score desc
+            | limit 10
+            """;
+        var optimized = planTypes(query);
+
+        var topN = as(optimized, TopN.class);
+        var filter = as(topN.child(), Filter.class);
+        var knn = as(filter.condition(), Knn.class);
+        assertThat(knn.k(), equalTo(10));
+    }
+
+    public void testKnnWithMultipleLimitsAfterTopN() {
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());
+
+        var query = """
+            from test metadata _score
+            | where knn(dense_vector, [0, 1, 2])
+            | limit 20
+            | sort _score desc
+            | limit 10
+            """;
+        var optimized = planTypes(query);
+
+        var topN = as(optimized, TopN.class);
+        assertThat(topN.limit().fold(FoldContext.small()), equalTo(10));
+        var limit = as(topN.child(), Limit.class);
+        var filter = as(limit.child(), Filter.class);
+        var knn = as(filter.condition(), Knn.class);
+        assertThat(knn.k(), equalTo(20));
+    }
+
+    public void testKnnWithMultipleLimitsCombined() {
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());
+
+        var query = """
+            from test metadata _score
+            | where knn(dense_vector, [0, 1, 2])
+            | limit 20
+            | limit 10
+            """;
+        var optimized = planTypes(query);
+
+        var limit = as(optimized, Limit.class);
+        assertThat(limit.limit().fold(FoldContext.small()), equalTo(10));
+        var filter = as(limit.child(), Filter.class);
+        var knn = as(filter.condition(), Knn.class);
+        assertThat(knn.k(), equalTo(10));
+    }
+
+    public void testKnnWithMultipleClauses() {
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());
+
+        var query = """
+            from test metadata _score
+            | where knn(dense_vector, [0, 1, 2]) and match(keyword, "test")
+            | where knn(dense_vector, [1, 2, 3])
+            | sort _score
+            | limit 10
+            """;
+        var optimized = planTypes(query);
+
+        var topN = as(optimized, TopN.class);
+        assertThat(topN.limit().fold(FoldContext.small()), equalTo(10));
+        var filter = as(topN.child(), Filter.class);
+        var firstAnd = as(filter.condition(), And.class);
+        var fistKnn = as(firstAnd.right(), Knn.class);
+        assertThat(((Literal) fistKnn.query()).value(), is(List.of(1.0f, 2.0f, 3.0f)));
+        var secondAnd = as(firstAnd.left(), And.class);
+        var secondKnn = as(secondAnd.left(), Knn.class);
+        assertThat(((Literal) secondKnn.query()).value(), is(List.of(0.0f, 1.0f, 2.0f)));
+    }
+
+    public void testKnnWithStats() {
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());
+
+        assertThat(
+            typesError("from test | where knn(dense_vector, [0, 1, 2]) | stats c = count(*)"),
+            containsString("Knn function must be used with a LIMIT clause")
+        );
+    }
+
+    public void testKnnWithRerankAmdTopN() {
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());
+
+        assertThat(typesError("""
+            from test metadata _score
+            | where knn(dense_vector, [0, 1, 2])
+            | rerank "some text" on text with { "inference_id" : "reranking-inference-id" }
+            | sort _score desc
+            | limit 10
+            """), containsString("Knn function must be used with a LIMIT clause"));
+    }
+
+    public void testKnnWithRerankAmdLimit() {
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V4.isEnabled());
+
+        var query = """
+            from test metadata _score
+            | where knn(dense_vector, [0, 1, 2])
+            | rerank "some text" on text with { "inference_id" : "reranking-inference-id" }
+            | limit 100
+            """;
+
+        var optimized = planTypes(query);
+
+        var rerank = as(optimized, Rerank.class);
+        var limit = as(rerank.child(), Limit.class);
+        assertThat(limit.limit().fold(FoldContext.small()), equalTo(100));
+        var filter = as(limit.child(), Filter.class);
+        var knn = as(filter.condition(), Knn.class);
+        assertThat(knn.k(), equalTo(100));
+    }
+
     private LogicalPlanOptimizer getCustomRulesLogicalPlanOptimizer(List<RuleExecutor.Batch<LogicalPlan>> batches) {
         LogicalOptimizerContext context = new LogicalOptimizerContext(EsqlTestUtils.TEST_CFG, FoldContext.small());
         LogicalPlanOptimizer customOptimizer = new LogicalPlanOptimizer(context) {