Browse Source

ESQL - KNN function uses prefilters when pushed down to Lucene (#131004)

Carlos Delgado 3 months ago
parent
commit
6ffe27d030
23 changed files with 842 additions and 69 deletions
  1. 1 1
      docs/reference/query-languages/esql/kibana/docs/functions/knn.md
  2. 66 23
      x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec
  3. 25 3
      x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java
  4. 1 1
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
  5. 5 1
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
  6. 2 3
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java
  7. 2 1
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Kql.java
  8. 2 1
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Match.java
  9. 2 1
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MatchPhrase.java
  10. 2 1
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MultiMatch.java
  11. 2 1
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/QueryString.java
  12. 2 1
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Term.java
  13. 61 13
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java
  14. 1 1
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java
  15. 2 0
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java
  16. 130 0
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownConjunctionsToKnnPrefilters.java
  17. 16 3
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java
  18. 3 2
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java
  19. 1 1
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java
  20. 9 9
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java
  21. 1 1
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java
  22. 335 1
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java
  23. 171 0
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java

+ 1 - 1
docs/reference/query-languages/esql/kibana/docs/functions/knn.md

@@ -1,4 +1,4 @@
-% This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it.
+% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.
 
 ### KNN
 Finds the k nearest vectors to a query vector, as measured by a similarity metric. knn function finds nearest vectors through approximate search on indexed dense_vectors.

+ 66 - 23
x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec

@@ -3,7 +3,7 @@
 # top-n query at the shard level 
 
 knnSearch
-required_capability: knn_function_v2
+required_capability: knn_function_v3
 
 // tag::knn-function[]
 from colors metadata _score 
@@ -30,7 +30,7 @@ chartreuse | [127.0, 255.0, 0.0]
 ;
 
 knnSearchWithSimilarityOption
-required_capability: knn_function_v2
+required_capability: knn_function_v3
 
 from colors metadata _score 
 | where knn(rgb_vector, [255,192,203], 140, {"similarity": 40})
@@ -46,14 +46,13 @@ wheat      | [245.0, 222.0, 179.0]
 ;
 
 knnHybridSearch
-required_capability: knn_function_v2
+required_capability: knn_function_v3
 
 from colors metadata _score 
-| where match(color, "blue") or knn(rgb_vector, [65,105,225], 140)
+| where match(color, "blue") or knn(rgb_vector, [65,105,225], 10)
 | where primary == true
 | sort _score desc, color asc
 | keep color, rgb_vector
-| limit 10
 ;
 
 color:text | rgb_vector:dense_vector
@@ -68,21 +67,45 @@ red        | [255.0, 0.0, 0.0]
 yellow     | [255.0, 255.0, 0.0]
 ;
 
-knnWithMultipleFunctions
-required_capability: knn_function_v2
+knnWithPrefilter
+required_capability: knn_function_v3
 
 from colors metadata _score
-| where knn(rgb_vector, [128,128,0], 140) and match(color, "olive") 
+| where knn(rgb_vector, [128,128,0], 10) and (match(color, "olive") or match(color, "green")) 
 | sort _score desc, color asc
 | keep color, rgb_vector
 ;
 
 color:text       | rgb_vector:dense_vector
 olive            | [128.0, 128.0, 0.0]
+green            | [0.0, 128.0, 0.0]
+;
+
+knnWithNegatedPrefilter
+required_capability: knn_function_v3
+
+from colors metadata _score
+| where knn(rgb_vector, [128,128,0], 10) and not (match(color, "olive") or match(color, "chocolate")) 
+| sort _score desc, color asc
+| keep color, rgb_vector
+| LIMIT 10
+;
+
+color:text | rgb_vector:dense_vector
+sienna     | [160.0, 82.0, 45.0]
+peru       | [205.0, 133.0, 63.0]
+golden rod | [218.0, 165.0, 32.0]
+brown      | [165.0, 42.0, 42.0]
+firebrick  | [178.0, 34.0, 34.0]
+chartreuse | [127.0, 255.0, 0.0]
+gray       | [128.0, 128.0, 128.0]
+green      | [0.0, 128.0, 0.0]
+maroon     | [128.0, 0.0, 0.0]
+orange     | [255.0, 165.0, 0.0]
 ;
 
 knnAfterKeep
-required_capability: knn_function_v2
+required_capability: knn_function_v3
 
 from colors metadata _score
 | keep rgb_vector, color, _score 
@@ -101,7 +124,7 @@ rgb_vector:dense_vector
 ;
 
 knnAfterDrop
-required_capability: knn_function_v2
+required_capability: knn_function_v3
 
 from colors metadata _score
 | drop primary
@@ -120,7 +143,7 @@ lime           | [0.0, 255.0, 0.0]
 ;
 
 knnAfterEval
-required_capability: knn_function_v2
+required_capability: knn_function_v3
 
 from colors metadata _score
 | eval composed_name = locate(color, " ") > 0 
@@ -139,14 +162,12 @@ golden rod | true
 ;
 
 knnWithConjunction
-required_capability: knn_function_v2
+required_capability: knn_function_v3
 
-# TODO We need kNN prefiltering here so we get more candidates that pass the filter
 from colors metadata _score 
-| where knn(rgb_vector, [255,255,238], 140) and hex_code like "#FFF*" 
+| where knn(rgb_vector, [255,255,238], 10) and hex_code like "#FFF*" 
 | sort _score desc, color asc
 | keep color, hex_code, rgb_vector
-| limit 10
 ;
 
 color:text    | hex_code:keyword | rgb_vector:dense_vector
@@ -160,11 +181,10 @@ yellow        | #FFFF00          | [255.0, 255.0, 0.0]
 ;
 
 knnWithDisjunctionAndFiltersConjunction
-required_capability: knn_function_v2
+required_capability: knn_function_v3
 
-# TODO We need kNN prefiltering here so we get more candidates that pass the filter
 from colors metadata _score 
-| where (knn(rgb_vector, [0,255,255], 140) or knn(rgb_vector, [128, 0, 255], 140)) and primary == true 
+| where (knn(rgb_vector, [0,255,255], 140) or knn(rgb_vector, [128, 0, 255], 10)) and primary == true 
 | keep color, rgb_vector, _score
 | sort _score desc, color asc
 | drop _score
@@ -183,8 +203,31 @@ red        | [255.0, 0.0, 0.0]
 yellow     | [255.0, 255.0, 0.0]
 ;
 
+knnWithNegationsAndFiltersConjunction
+required_capability: knn_function_v3
+
+from colors metadata _score 
+| where (knn(rgb_vector, [0,255,255], 140) and not(primary == true and match(color, "blue"))) 
+| sort _score desc, color asc
+| keep color, rgb_vector
+| limit 10
+;
+
+color:text  | rgb_vector:dense_vector
+cyan        | [0.0, 255.0, 255.0]
+turquoise   | [64.0, 224.0, 208.0]
+aqua marine | [127.0, 255.0, 212.0]
+teal        | [0.0, 128.0, 128.0]
+silver      | [192.0, 192.0, 192.0]
+gray        | [128.0, 128.0, 128.0]
+gainsboro   | [220.0, 220.0, 220.0]
+thistle     | [216.0, 191.0, 216.0]
+lavender    | [230.0, 230.0, 250.0]
+azure       | [240.0, 255.0, 255.0]
+;
+
 knnWithNonPushableConjunction
-required_capability: knn_function_v2
+required_capability: knn_function_v3
 
 from colors metadata _score
 | eval composed_name = locate(color, " ") > 0 
@@ -208,7 +251,7 @@ maroon     | false
 ;
 
 testKnnWithNonPushableDisjunctions
-required_capability: knn_function_v2
+required_capability: knn_function_v3
 
 from colors metadata _score 
 | where knn(rgb_vector, [128,128,0], 140, {"similarity": 30}) or length(color) > 10 
@@ -224,7 +267,7 @@ papaya whip
 ;
 
 testKnnWithNonPushableDisjunctionsOnComplexExpressions
-required_capability: knn_function_v2
+required_capability: knn_function_v3
 
 from colors metadata _score 
 | where (knn(rgb_vector, [128,128,0], 140, {"similarity": 70}) and length(color) < 10) or (knn(rgb_vector, [128,0,128], 140, {"similarity": 60}) and primary == false) 
@@ -239,7 +282,7 @@ indigo       | false
 ;
 
 testKnnInStatsNonPushable
-required_capability: knn_function_v2
+required_capability: knn_function_v3
 
 from colors 
 | where length(color) < 10 
@@ -251,7 +294,7 @@ c: long
 ;
 
 testKnnInStatsWithGrouping
-required_capability: knn_function_v2
+required_capability: knn_function_v3
 required_capability: full_text_functions_in_stats_where
 
 from colors 

+ 25 - 3
x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java

@@ -114,6 +114,29 @@ public class KnnFunctionIT extends AbstractEsqlIntegTestCase {
         }
     }
 
+    public void testKnnWithPrefilters() {
+        float[] queryVector = new float[numDims];
+        Arrays.fill(queryVector, 1.0f);
+
+        // We retrieve 5 from knn, but must be prefiltered with id > 5 or no result will be returned as it would be post-filtered
+        var query = String.format(Locale.ROOT, """
+            FROM test METADATA _score
+            | WHERE knn(vector, %s, 5) AND id > 5
+            | KEEP id, floats, _score, vector
+            | SORT _score DESC
+            | LIMIT 5
+            """, Arrays.toString(queryVector));
+
+        try (var resp = run(query)) {
+            assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector"));
+            assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector"));
+
+            List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
+            // K = 5, 1 more for every id > 10
+            assertEquals(5, valuesList.size());
+        }
+    }
+
     public void testKnnWithLookupJoin() {
         float[] queryVector = new float[numDims];
         Arrays.fill(queryVector, 1.0f);
@@ -136,7 +159,7 @@ public class KnnFunctionIT extends AbstractEsqlIntegTestCase {
 
     @Before
     public void setup() throws IOException {
-        assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled());
+        assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
 
         var indexName = "test";
         var client = client().admin().indices();
@@ -163,7 +186,7 @@ public class KnnFunctionIT extends AbstractEsqlIntegTestCase {
         var createRequest = client.prepareCreate(indexName).setMapping(mapping).setSettings(settingsBuilder.build());
         assertAcked(createRequest);
 
-        numDocs = randomIntBetween(10, 20);
+        numDocs = randomIntBetween(15, 25);
         numDims = randomIntBetween(3, 10);
         IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];
         float value = 0.0f;
@@ -202,6 +225,5 @@ public class KnnFunctionIT extends AbstractEsqlIntegTestCase {
 
         var createRequest = client.prepareCreate(lookupIndexName).setMapping(mapping).setSettings(settingsBuilder.build());
         assertAcked(createRequest);
-
     }
 }

+ 1 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

@@ -1213,7 +1213,7 @@ public class EsqlCapabilities {
         /**
          * Support knn function
          */
-        KNN_FUNCTION_V2(Build.current().isSnapshot()),
+        KNN_FUNCTION_V3(Build.current().isSnapshot()),
 
         /**
          * Support for the LIKE operator with a list of wildcards.

+ 5 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java

@@ -481,7 +481,7 @@ public class EsqlFunctionRegistry {
                 def(FirstOverTime.class, uni(FirstOverTime::new), "first_over_time"),
                 def(Score.class, uni(Score::new), Score.NAME),
                 def(Term.class, bi(Term::new), "term"),
-                def(Knn.class, Knn::new, "knn"),
+                def(Knn.class, quad(Knn::new), "knn"),
                 def(StGeohash.class, StGeohash::new, "st_geohash"),
                 def(StGeohashToLong.class, StGeohashToLong::new, "st_geohash_to_long"),
                 def(StGeohashToString.class, StGeohashToString::new, "st_geohash_to_string"),
@@ -1208,4 +1208,8 @@ public class EsqlFunctionRegistry {
         return function;
     }
 
+    private static <T extends Function> QuaternaryBuilder<T> quad(QuaternaryBuilder<T> function) {
+        return function;
+    }
+
 }

+ 2 - 3
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java

@@ -166,20 +166,19 @@ public abstract class FullTextFunction extends Function
 
     @Override
     public Translatable translatable(LucenePushdownPredicates pushdownPredicates) {
-        // In isolation, full text functions are pushable to source. We check if there are no disjunctions in Or conditions
         return Translatable.YES;
     }
 
     @Override
     public Query asQuery(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
-        return queryBuilder != null ? new TranslationAwareExpressionQuery(source(), queryBuilder) : translate(handler);
+        return queryBuilder != null ? new TranslationAwareExpressionQuery(source(), queryBuilder) : translate(pushdownPredicates, handler);
     }
 
     public QueryBuilder queryBuilder() {
         return queryBuilder;
     }
 
-    protected abstract Query translate(TranslatorHandler handler);
+    protected abstract Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler);
 
     public abstract Expression replaceQueryBuilder(QueryBuilder queryBuilder);
 

+ 2 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Kql.java

@@ -22,6 +22,7 @@ import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecyc
 import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
 import org.elasticsearch.xpack.esql.expression.function.Param;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
+import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
 import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
 import org.elasticsearch.xpack.esql.querydsl.query.KqlQuery;
 
@@ -93,7 +94,7 @@ public class Kql extends FullTextFunction {
     }
 
     @Override
-    protected Query translate(TranslatorHandler handler) {
+    protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
         return new KqlQuery(source(), Objects.toString(queryAsObject()));
     }
 

+ 2 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Match.java

@@ -35,6 +35,7 @@ import org.elasticsearch.xpack.esql.expression.function.MapParam;
 import org.elasticsearch.xpack.esql.expression.function.OptionalArgument;
 import org.elasticsearch.xpack.esql.expression.function.Param;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
+import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
 import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
 import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
 import org.elasticsearch.xpack.esql.querydsl.query.MatchQuery;
@@ -423,7 +424,7 @@ public class Match extends FullTextFunction implements OptionalArgument, PostAna
     }
 
     @Override
-    protected Query translate(TranslatorHandler handler) {
+    protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
         var fieldAttribute = fieldAsFieldAttribute();
         Check.notNull(fieldAttribute, "Match must have a field attribute as the first argument");
         String fieldName = getNameFromFieldAttribute(fieldAttribute);

+ 2 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MatchPhrase.java

@@ -32,6 +32,7 @@ import org.elasticsearch.xpack.esql.expression.function.MapParam;
 import org.elasticsearch.xpack.esql.expression.function.OptionalArgument;
 import org.elasticsearch.xpack.esql.expression.function.Param;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
+import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
 import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
 import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
 import org.elasticsearch.xpack.esql.querydsl.query.MatchPhraseQuery;
@@ -278,7 +279,7 @@ public class MatchPhrase extends FullTextFunction implements OptionalArgument, P
     }
 
     @Override
-    protected Query translate(TranslatorHandler handler) {
+    protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
         var fieldAttribute = fieldAsFieldAttribute();
         Check.notNull(fieldAttribute, "MatchPhrase must have a field attribute as the first argument");
         String fieldName = getNameFromFieldAttribute(fieldAttribute);

+ 2 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MultiMatch.java

@@ -31,6 +31,7 @@ import org.elasticsearch.xpack.esql.expression.function.MapParam;
 import org.elasticsearch.xpack.esql.expression.function.OptionalArgument;
 import org.elasticsearch.xpack.esql.expression.function.Param;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
+import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
 import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
 import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
 import org.elasticsearch.xpack.esql.querydsl.query.MultiMatchQuery;
@@ -335,7 +336,7 @@ public class MultiMatch extends FullTextFunction implements OptionalArgument, Po
     }
 
     @Override
-    protected Query translate(TranslatorHandler handler) {
+    protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
         Map<String, Float> fieldsWithBoost = new HashMap<>();
         for (Expression field : fields) {
             var fieldAttribute = Match.fieldAsFieldAttribute(field);

+ 2 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/QueryString.java

@@ -28,6 +28,7 @@ import org.elasticsearch.xpack.esql.expression.function.MapParam;
 import org.elasticsearch.xpack.esql.expression.function.OptionalArgument;
 import org.elasticsearch.xpack.esql.expression.function.Param;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
+import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
 import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
 
 import java.io.IOException;
@@ -345,7 +346,7 @@ public class QueryString extends FullTextFunction implements OptionalArgument {
     }
 
     @Override
-    protected Query translate(TranslatorHandler handler) {
+    protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
         return new QueryStringQuery(source(), Objects.toString(queryAsObject()), Map.of(), queryStringOptions());
     }
 

+ 2 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Term.java

@@ -27,6 +27,7 @@ import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecyc
 import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
 import org.elasticsearch.xpack.esql.expression.function.Param;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
+import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
 import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
 import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
 
@@ -130,7 +131,7 @@ public class Term extends FullTextFunction implements PostAnalysisPlanVerificati
     }
 
     @Override
-    protected Query translate(TranslatorHandler handler) {
+    protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
         // Uses a term query that contributes to scoring
         return new TermQuery(source(), ((FieldAttribute) field()).name(), queryAsObject(), false, true);
     }

+ 61 - 13
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java

@@ -12,6 +12,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.xpack.esql.capabilities.PostAnalysisPlanVerificationAware;
+import org.elasticsearch.xpack.esql.capabilities.TranslationAware;
 import org.elasticsearch.xpack.esql.common.Failures;
 import org.elasticsearch.xpack.esql.core.InvalidArgumentException;
 import org.elasticsearch.xpack.esql.core.expression.Expression;
@@ -33,6 +34,7 @@ import org.elasticsearch.xpack.esql.expression.function.Param;
 import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextFunction;
 import org.elasticsearch.xpack.esql.expression.function.fulltext.Match;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
+import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
 import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
 import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
 import org.elasticsearch.xpack.esql.querydsl.query.KnnQuery;
@@ -70,6 +72,8 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
     // k is not serialized as it's already included in the query builder on the rewrite step before being sent to data nodes
     private final transient Expression k;
     private final Expression options;
+    // Expressions to be used as prefilters in knn query
+    private final List<Expression> filterExpressions;
 
     public static final Map<String, DataType> ALLOWED_OPTIONS = Map.ofEntries(
         entry(NUM_CANDS_FIELD.getPreferredName(), INTEGER),
@@ -139,14 +143,23 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
             optional = true
         ) Expression options
     ) {
-        this(source, field, query, k, options, null);
+        this(source, field, query, k, options, null, List.of());
     }
 
-    private Knn(Source source, Expression field, Expression query, Expression k, Expression options, QueryBuilder queryBuilder) {
+    public Knn(
+        Source source,
+        Expression field,
+        Expression query,
+        Expression k,
+        Expression options,
+        QueryBuilder queryBuilder,
+        List<Expression> filterExpressions
+    ) {
         super(source, query, expressionList(field, query, k, options), queryBuilder);
         this.field = field;
         this.k = k;
         this.options = options;
+        this.filterExpressions = filterExpressions;
     }
 
     private static List<Expression> expressionList(Expression field, Expression query, Expression k, Expression options) {
@@ -174,6 +187,10 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
         return options;
     }
 
+    public List<Expression> filterExpressions() {
+        return filterExpressions;
+    }
+
     @Override
     public DataType dataType() {
         return DataType.BOOLEAN;
@@ -236,10 +253,26 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
     }
 
     @Override
-    protected Query translate(TranslatorHandler handler) {
+    public Expression replaceQueryBuilder(QueryBuilder queryBuilder) {
+        return new Knn(source(), field(), query(), k(), options(), queryBuilder, filterExpressions());
+    }
+
+    @Override
+    public Translatable translatable(LucenePushdownPredicates pushdownPredicates) {
+        Translatable translatable = super.translatable(pushdownPredicates);
+        // We need to check whether filter expressions are translatable as well
+        for (Expression filterExpression : filterExpressions()) {
+            translatable = translatable.merge(TranslationAware.translatable(filterExpression, pushdownPredicates));
+        }
+
+        return translatable;
+    }
+
+    @Override
+    protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
         var fieldAttribute = Match.fieldAsFieldAttribute(field());
 
-        Check.notNull(fieldAttribute, "Match must have a field attribute as the first argument");
+        Check.notNull(fieldAttribute, "Knn must have a field attribute as the first argument");
         String fieldName = getNameFromFieldAttribute(fieldAttribute);
         @SuppressWarnings("unchecked")
         List<Number> queryFolded = (List<Number>) query().fold(FoldContext.small() /* TODO remove me */);
@@ -252,12 +285,23 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
         Map<String, Object> opts = queryOptions();
         opts.put(K_FIELD.getPreferredName(), kValue);
 
-        return new KnnQuery(source(), fieldName, queryAsFloats, opts);
+        List<QueryBuilder> filterQueries = new ArrayList<>();
+        for (Expression filterExpression : filterExpressions()) {
+            if (filterExpression instanceof TranslationAware translationAware) {
+                // We can only translate filter expressions that are translatable. In case any is not translatable,
+                // Knn won't be pushed down as it will not be translatable so it's safe not to translate all filters and check them
+                // when creating an evaluator for the non-pushed down query
+                if (translationAware.translatable(pushdownPredicates) == Translatable.YES) {
+                    filterQueries.add(handler.asQuery(pushdownPredicates, filterExpression).toQueryBuilder());
+                }
+            }
+        }
+
+        return new KnnQuery(source(), fieldName, queryAsFloats, opts, filterQueries);
     }
 
-    @Override
-    public Expression replaceQueryBuilder(QueryBuilder queryBuilder) {
-        return new Knn(source(), field(), query(), k(), options(), queryBuilder);
+    public Expression withFilters(List<Expression> filterExpressions) {
+        return new Knn(source(), field(), query(), k(), options(), queryBuilder(), filterExpressions);
     }
 
     private Map<String, Object> queryOptions() throws InvalidArgumentException {
@@ -284,13 +328,14 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
             newChildren.get(1),
             newChildren.get(2),
             newChildren.size() > 3 ? newChildren.get(3) : null,
-            queryBuilder()
+            queryBuilder(),
+            filterExpressions()
         );
     }
 
     @Override
     protected NodeInfo<? extends Expression> info() {
-        return NodeInfo.create(this, Knn::new, field(), query(), k(), options());
+        return NodeInfo.create(this, Knn::new, field(), query(), k(), options(), queryBuilder(), filterExpressions());
     }
 
     @Override
@@ -303,7 +348,8 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
         Expression field = in.readNamedWriteable(Expression.class);
         Expression query = in.readNamedWriteable(Expression.class);
         QueryBuilder queryBuilder = in.readOptionalNamedWriteable(QueryBuilder.class);
-        return new Knn(source, field, query, null, null, queryBuilder);
+        List<Expression> filterExpressions = in.readNamedWriteableCollectionAsList(Expression.class);
+        return new Knn(source, field, query, null, null, queryBuilder, filterExpressions);
     }
 
     @Override
@@ -312,6 +358,7 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
         out.writeNamedWriteable(field());
         out.writeNamedWriteable(query());
         out.writeOptionalNamedWriteable(queryBuilder());
+        out.writeNamedWriteableCollection(filterExpressions());
     }
 
     @Override
@@ -322,12 +369,13 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
         Knn knn = (Knn) o;
         return Objects.equals(field(), knn.field())
             && Objects.equals(query(), knn.query())
-            && Objects.equals(queryBuilder(), knn.queryBuilder());
+            && Objects.equals(queryBuilder(), knn.queryBuilder())
+            && Objects.equals(filterExpressions(), knn.filterExpressions());
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(field(), query(), queryBuilder());
+        return Objects.hash(field(), query(), queryBuilder(), filterExpressions());
     }
 
 }

+ 1 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java

@@ -27,7 +27,7 @@ public final class VectorWritables {
     public static List<NamedWriteableRegistry.Entry> getNamedWritables() {
         List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
 
-        if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
+        if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) {
             entries.add(Knn.ENTRY);
         }
         if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {

+ 2 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java

@@ -38,6 +38,7 @@ import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownAndCombineFi
 import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownAndCombineLimits;
 import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownAndCombineOrderBy;
 import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownAndCombineSample;
+import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownConjunctionsToKnnPrefilters;
 import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownEnrich;
 import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownEval;
 import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownInferencePlan;
@@ -192,6 +193,7 @@ public class LogicalPlanOptimizer extends ParameterizedRuleExecutor<LogicalPlan,
             new PruneLiteralsInOrderBy(),
             new PushDownAndCombineLimits(),
             new PushDownAndCombineFilters(),
+            new PushDownConjunctionsToKnnPrefilters(),
             new PushDownAndCombineSample(),
             new PushDownInferencePlan(),
             new PushDownEval(),

+ 130 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownConjunctionsToKnnPrefilters.java

@@ -0,0 +1,130 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.optimizer.rules.logical;
+
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
+import org.elasticsearch.xpack.esql.expression.predicate.logical.And;
+import org.elasticsearch.xpack.esql.expression.predicate.logical.BinaryLogic;
+import org.elasticsearch.xpack.esql.plan.logical.Filter;
+import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+import java.util.Stack;
+
+/**
+ * Rewrites an expression tree to push down conjunctions in the prefilter of {@link Knn} functions.
+ * knn functions won't contain other knn functions as a prefilter, to avoid circular dependencies.
+ *  Given an expression tree like {@code (A OR B) AND (C AND knn())} this rule will rewrite it to
+ *     {@code (A OR B) AND (C AND knn(filterExpressions = [(A OR B), C]))}
+*/
+public class PushDownConjunctionsToKnnPrefilters extends OptimizerRules.OptimizerRule<Filter> {
+
+    @Override
+    protected LogicalPlan rule(Filter filter) {
+        Stack<Expression> filters = new Stack<>();
+        Expression condition = filter.condition();
+        Expression newCondition = pushConjunctionsToKnn(condition, filters, null);
+
+        return condition.equals(newCondition) ? filter : filter.with(newCondition);
+    }
+
+    /**
+     * Updates knn function prefilters. This method processes conjunctions so knn functions on one side of the conjunction receive
+     * the other side of the conjunction as a prefilter
+     *
+     * @param expression expression to process recursively
+     * @param filters current filters to apply to the expression. They contain expressions on the other side of the traversed conjunctions
+     * @param addedFilter a new filter to add to the list of filters for the processing
+     * @return the updated expression, or the original expression if it doesn't need to be updated
+     */
+    private static Expression pushConjunctionsToKnn(Expression expression, Stack<Expression> filters, Expression addedFilter) {
+        if (addedFilter != null) {
+            filters.push(addedFilter);
+        }
+        Expression result = switch (expression) {
+            case And and:
+                // Traverse both sides of the And, using the other side as the added filter
+                Expression newLeft = pushConjunctionsToKnn(and.left(), filters, and.right());
+                Expression newRight = pushConjunctionsToKnn(and.right(), filters, and.left());
+                if (newLeft.equals(and.left()) && newRight.equals(and.right())) {
+                    yield and;
+                }
+                yield and.replaceChildrenSameSize(List.of(newLeft, newRight));
+            case Knn knn:
+                // We don't want knn expressions to have other knn expressions as a prefilter to avoid circular dependencies
+                List<Expression> newFilters = filters.stream()
+                    .map(PushDownConjunctionsToKnnPrefilters::removeKnn)
+                    .filter(Objects::nonNull)
+                    .toList();
+                if (newFilters.equals(knn.filterExpressions())) {
+                    yield knn;
+                }
+                yield knn.withFilters(newFilters);
+            default:
+                List<Expression> children = expression.children();
+                boolean childrenChanged = false;
+
+                // This copies transformChildren algorithm to avoid unnecessary changes
+                List<Expression> transformedChildren = null;
+
+                for (int i = 0, s = children.size(); i < s; i++) {
+                    Expression child = children.get(i);
+                    Expression next = pushConjunctionsToKnn(child, filters, null);
+                    if (child.equals(next) == false) {
+                        // lazy copy + replacement in place
+                        if (childrenChanged == false) {
+                            childrenChanged = true;
+                            transformedChildren = new ArrayList<>(children);
+                        }
+                        transformedChildren.set(i, next);
+                    }
+                }
+
+                yield (childrenChanged ? expression.replaceChildrenSameSize(transformedChildren) : expression);
+        };
+
+        if (addedFilter != null) {
+            filters.pop();
+        }
+
+        return result;
+    }
+
+    /**
+     * Removes knn functions from the expression tree
+     * @param expression expression to process
+     * @return expression without knn functions, or null if the expression is a knn function
+     */
+    private static Expression removeKnn(Expression expression) {
+        if (expression.children().isEmpty()) {
+            return expression;
+        }
+        if (expression instanceof Knn) {
+            return null;
+        }
+
+        List<Expression> filteredChildren = expression.children()
+            .stream()
+            .map(PushDownConjunctionsToKnnPrefilters::removeKnn)
+            .filter(Objects::nonNull)
+            .toList();
+        if (filteredChildren.equals(expression.children())) {
+            return expression;
+        } else if (filteredChildren.isEmpty()) {
+            return null;
+        } else if (expression instanceof BinaryLogic && filteredChildren.size() == 1) {
+            // Simplify an AND / OR expression to a single child
+            return filteredChildren.getFirst();
+        } else {
+            return expression.replaceChildrenSameSize(filteredChildren);
+        }
+    }
+}

+ 16 - 3
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java

@@ -13,7 +13,9 @@ import org.elasticsearch.search.vectors.RescoreVectorBuilder;
 import org.elasticsearch.xpack.esql.core.querydsl.query.Query;
 import org.elasticsearch.xpack.esql.core.tree.Source;
 
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 
@@ -27,15 +29,17 @@ public class KnnQuery extends Query {
     private final String field;
     private final float[] query;
     private final Map<String, Object> options;
+    private final List<QueryBuilder> filterQueries;
 
     public static final String RESCORE_OVERSAMPLE_FIELD = "rescore_oversample";
 
-    public KnnQuery(Source source, String field, float[] query, Map<String, Object> options) {
+    public KnnQuery(Source source, String field, float[] query, Map<String, Object> options, List<QueryBuilder> filterQueries) {
         super(source);
         assert options != null;
         this.field = field;
         this.query = query;
         this.options = options;
+        this.filterQueries = new ArrayList<>(filterQueries);
     }
 
     @Override
@@ -50,6 +54,9 @@ public class KnnQuery extends Query {
         Float vectorSimilarity = (Float) options.get(VECTOR_SIMILARITY_FIELD.getPreferredName());
 
         KnnVectorQueryBuilder queryBuilder = new KnnVectorQueryBuilder(field, query, k, numCands, rescoreVectorBuilder, vectorSimilarity);
+        for (QueryBuilder filter : filterQueries) {
+            queryBuilder.addFilterQuery(filter);
+        }
         Number boost = (Number) options.get(BOOST_FIELD.getPreferredName());
         if (boost != null) {
             queryBuilder.boost(boost.floatValue());
@@ -66,15 +73,17 @@ public class KnnQuery extends Query {
     public boolean equals(Object o) {
         if (super.equals(o) == false) return false;
 
+        if (o == null || getClass() != o.getClass()) return false;
         KnnQuery knnQuery = (KnnQuery) o;
         return Objects.equals(field, knnQuery.field)
             && Objects.deepEquals(query, knnQuery.query)
-            && Objects.equals(options, knnQuery.options);
+            && Objects.equals(options, knnQuery.options)
+            && Objects.equals(filterQueries, knnQuery.filterQueries);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(super.hashCode(), field, Arrays.hashCode(query), options);
+        return Objects.hash(super.hashCode(), field, Arrays.hashCode(query), options, filterQueries);
     }
 
     @Override
@@ -86,4 +95,8 @@ public class KnnQuery extends Query {
     public boolean containsPlan() {
         return false;
     }
+
+    public List<QueryBuilder> filterQueries() {
+        return filterQueries;
+    }
 }

+ 3 - 2
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java

@@ -212,9 +212,10 @@ public class EsqlSession {
         analyzedPlan(parsed, executionInfo, request.filter(), new EsqlCCSUtils.CssPartialErrorsActionListener(executionInfo, listener) {
             @Override
             public void onResponse(LogicalPlan analyzedPlan) {
+                LogicalPlan optimizedPlan = optimizedPlan(analyzedPlan);
                 preMapper.preMapper(
-                    analyzedPlan,
-                    listener.delegateFailureAndWrap((l, p) -> executeOptimizedPlan(request, executionInfo, planRunner, optimizedPlan(p), l))
+                    optimizedPlan,
+                    listener.delegateFailureAndWrap((l, p) -> executeOptimizedPlan(request, executionInfo, planRunner, p, l))
                 );
             }
         });

+ 1 - 1
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java

@@ -303,7 +303,7 @@ public class CsvTests extends ESTestCase {
             );
             assumeFalse(
                 "can't use KNN function in csv tests",
-                testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KNN_FUNCTION_V2.capabilityName())
+                testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KNN_FUNCTION_V3.capabilityName())
             );
             assumeFalse(
                 "lookup join disabled for csv tests",

+ 9 - 9
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java

@@ -1237,7 +1237,7 @@ public class VerifierTests extends ESTestCase {
             checkFieldBasedWithNonIndexedColumn("Term", "term(text, \"cat\")", "function");
             checkFieldBasedFunctionNotAllowedAfterCommands("Term", "function", "term(title, \"Meditation\")");
         }
-        if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
+        if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) {
             checkFieldBasedFunctionNotAllowedAfterCommands("KNN", "function", "knn(vector, [1, 2, 3], 10)");
         }
     }
@@ -1370,7 +1370,7 @@ public class VerifierTests extends ESTestCase {
         if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) {
             checkFullTextFunctionsOnlyAllowedInWhere("MultiMatch", "multi_match(\"Meditation\", title, body)", "function");
         }
-        if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
+        if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) {
             checkFullTextFunctionsOnlyAllowedInWhere("KNN", "knn(vector, [0, 1, 2], 10)", "function");
         }
 
@@ -1417,7 +1417,7 @@ public class VerifierTests extends ESTestCase {
         if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) {
             checkWithFullTextFunctionsDisjunctions("term(title, \"Meditation\")");
         }
-        if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
+        if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) {
             checkWithFullTextFunctionsDisjunctions("knn(vector, [1, 2, 3], 10)");
         }
     }
@@ -1482,7 +1482,7 @@ public class VerifierTests extends ESTestCase {
         if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) {
             checkFullTextFunctionsWithNonBooleanFunctions("Term", "term(title, \"Meditation\")", "function");
         }
-        if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
+        if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) {
             checkFullTextFunctionsWithNonBooleanFunctions("KNN", "knn(vector, [1, 2, 3], 10)", "function");
         }
     }
@@ -1553,7 +1553,7 @@ public class VerifierTests extends ESTestCase {
         if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) {
             testFullTextFunctionTargetsExistingField("term(fist_name, \"Meditation\")");
         }
-        if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
+        if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) {
             testFullTextFunctionTargetsExistingField("knn(vector, [0, 1, 2], 10)");
         }
     }
@@ -2081,7 +2081,7 @@ public class VerifierTests extends ESTestCase {
         if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) {
             checkOptionDataTypes(MultiMatch.OPTIONS, "FROM test | WHERE MULTI_MATCH(\"Jean\", title, body, {\"%s\": %s})");
         }
-        if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
+        if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) {
             checkOptionDataTypes(Knn.ALLOWED_OPTIONS, "FROM test | WHERE KNN(vector, [0.1, 0.2, 0.3], 10, {\"%s\": %s})");
         }
     }
@@ -2169,7 +2169,7 @@ public class VerifierTests extends ESTestCase {
             checkFullTextFunctionNullArgs("term(null, \"query\")", "first");
             checkFullTextFunctionNullArgs("term(title, null)", "second");
         }
-        if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
+        if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) {
             checkFullTextFunctionNullArgs("knn(null, [0, 1, 2], 10)", "first");
             checkFullTextFunctionNullArgs("knn(vector, null, 10)", "second");
             checkFullTextFunctionNullArgs("knn(vector, [0, 1, 2], null)", "third");
@@ -2195,7 +2195,7 @@ public class VerifierTests extends ESTestCase {
         if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) {
             checkFullTextFunctionsConstantArg("term(title, tags)", "second");
         }
-        if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
+        if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) {
             checkFullTextFunctionsConstantArg("knn(vector, vector, 10)", "second");
             checkFullTextFunctionsConstantArg("knn(vector, [0, 1, 2], category)", "third");
         }
@@ -2226,7 +2226,7 @@ public class VerifierTests extends ESTestCase {
         if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) {
             checkFullTextFunctionsInStats("multi_match(\"Meditation\", title, body)");
         }
-        if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
+        if (EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled()) {
             checkFullTextFunctionsInStats("knn(vector, [0, 1, 2], 10)");
         }
     }

+ 1 - 1
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java

@@ -51,7 +51,7 @@ public class KnnTests extends AbstractFunctionTestCase {
 
     @Before
     public void checkCapability() {
-        assumeTrue("KNN is not enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled());
+        assumeTrue("KNN is not enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
     }
 
     private static List<TestCaseSupplier> testCaseSuppliers() {

+ 335 - 1
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java

@@ -63,6 +63,8 @@ import org.elasticsearch.xpack.esql.expression.function.fulltext.Kql;
 import org.elasticsearch.xpack.esql.expression.function.fulltext.Match;
 import org.elasticsearch.xpack.esql.expression.function.fulltext.MatchOperator;
 import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString;
+import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
+import org.elasticsearch.xpack.esql.expression.predicate.logical.And;
 import org.elasticsearch.xpack.esql.expression.predicate.logical.Or;
 import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan;
 import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual;
@@ -108,6 +110,7 @@ import org.junit.Before;
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.List;
 import java.util.Locale;
@@ -1371,7 +1374,7 @@ public class LocalPhysicalPlanOptimizerTests extends MapperServiceTestCase {
 
     public void testKnnOptionsPushDown() {
         assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled());
-        assumeTrue("knn capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled());
+        assumeTrue("knn capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
 
         String query = """
             from test
@@ -1836,6 +1839,308 @@ public class LocalPhysicalPlanOptimizerTests extends MapperServiceTestCase {
         aggExec.forEachDown(EsQueryExec.class, esQueryExec -> { assertNull(esQueryExec.query()); });
     }
 
+    public void testKnnPrefilters() {
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+
+        String query = """
+            from test
+            | where knn(dense_vector, [0, 1, 2], 10) and integer > 10
+            """;
+        var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json"));
+
+        var limit = as(plan, LimitExec.class);
+        var exchange = as(limit.child(), ExchangeExec.class);
+        var project = as(exchange.child(), ProjectExec.class);
+        var field = as(project.child(), FieldExtractExec.class);
+        var queryExec = as(field.child(), EsQueryExec.class);
+        QueryBuilder expectedFilterQueryBuilder = wrapWithSingleQuery(
+            query,
+            unscore(rangeQuery("integer").gt(10)),
+            "integer",
+            new Source(2, 45, "integer > 10")
+        );
+        KnnVectorQueryBuilder expectedKnnQueryBuilder = new KnnVectorQueryBuilder(
+            "dense_vector",
+            new float[] { 0, 1, 2 },
+            10,
+            null,
+            null,
+            null
+        ).addFilterQuery(expectedFilterQueryBuilder);
+        var expectedQuery = boolQuery().must(expectedKnnQueryBuilder).must(expectedFilterQueryBuilder);
+        assertEquals(expectedQuery.toString(), queryExec.query().toString());
+    }
+
+    public void testKnnPrefiltersWithMultipleFilters() {
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+
+        String query = """
+            from test
+            | where knn(dense_vector, [0, 1, 2], 10)
+            | where integer > 10
+            | where keyword == "test"
+            """;
+        var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json"));
+
+        var limit = as(plan, LimitExec.class);
+        var exchange = as(limit.child(), ExchangeExec.class);
+        var project = as(exchange.child(), ProjectExec.class);
+        var field = as(project.child(), FieldExtractExec.class);
+        var queryExec = as(field.child(), EsQueryExec.class);
+        var integerFilter = wrapWithSingleQuery(query, unscore(rangeQuery("integer").gt(10)), "integer", new Source(3, 8, "integer > 10"));
+        var keywordFilter = wrapWithSingleQuery(
+            query,
+            unscore(termQuery("keyword", "test")),
+            "keyword",
+            new Source(4, 8, "keyword == \"test\"")
+        );
+        QueryBuilder expectedFilterQueryBuilder = boolQuery().must(integerFilter).must(keywordFilter);
+        KnnVectorQueryBuilder expectedKnnQueryBuilder = new KnnVectorQueryBuilder(
+            "dense_vector",
+            new float[] { 0, 1, 2 },
+            10,
+            null,
+            null,
+            null
+        ).addFilterQuery(expectedFilterQueryBuilder);
+        var expectedQuery = boolQuery().must(expectedKnnQueryBuilder).must(integerFilter).must(keywordFilter);
+        assertEquals(expectedQuery.toString(), queryExec.query().toString());
+    }
+
+    public void testPushDownConjunctionsToKnnPrefilter() {
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+
+        String query = """
+            from test
+            | where knn(dense_vector, [0, 1, 2], 10) and integer > 10
+            """;
+        var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json"));
+
+        var limit = as(plan, LimitExec.class);
+        var exchange = as(limit.child(), ExchangeExec.class);
+        var project = as(exchange.child(), ProjectExec.class);
+        var field = as(project.child(), FieldExtractExec.class);
+        var queryExec = as(field.child(), EsQueryExec.class);
+
+        // The filter condition should be pushed down to both the KNN query and the main query
+        QueryBuilder expectedFilterQueryBuilder = wrapWithSingleQuery(
+            query,
+            unscore(rangeQuery("integer").gt(10)),
+            "integer",
+            new Source(2, 45, "integer > 10")
+        );
+
+        KnnVectorQueryBuilder expectedKnnQueryBuilder = new KnnVectorQueryBuilder(
+            "dense_vector",
+            new float[] { 0, 1, 2 },
+            10,
+            null,
+            null,
+            null
+        ).addFilterQuery(expectedFilterQueryBuilder);
+
+        var expectedQuery = boolQuery().must(expectedKnnQueryBuilder).must(expectedFilterQueryBuilder);
+
+        assertEquals(expectedQuery.toString(), queryExec.query().toString());
+    }
+
+    public void testPushDownNegatedConjunctionsToKnnPrefilter() {
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+
+        String query = """
+            from test
+            | where knn(dense_vector, [0, 1, 2], 10) and NOT integer > 10
+            """;
+        var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json"));
+
+        var limit = as(plan, LimitExec.class);
+        var exchange = as(limit.child(), ExchangeExec.class);
+        var project = as(exchange.child(), ProjectExec.class);
+        var field = as(project.child(), FieldExtractExec.class);
+        var queryExec = as(field.child(), EsQueryExec.class);
+
+        // The filter condition should be pushed down to both the KNN query and the main query
+        QueryBuilder expectedFilterQueryBuilder = wrapWithSingleQuery(
+            query,
+            unscore(boolQuery().mustNot(unscore(rangeQuery("integer").gt(10)))),
+            "integer",
+            new Source(2, 45, "NOT integer > 10")
+        );
+
+        KnnVectorQueryBuilder expectedKnnQueryBuilder = new KnnVectorQueryBuilder(
+            "dense_vector",
+            new float[] { 0, 1, 2 },
+            10,
+            null,
+            null,
+            null
+        ).addFilterQuery(expectedFilterQueryBuilder);
+
+        var expectedQuery = boolQuery().must(expectedKnnQueryBuilder).must(expectedFilterQueryBuilder);
+
+        assertEquals(expectedQuery.toString(), queryExec.query().toString());
+    }
+
+    public void testNotPushDownDisjunctionsToKnnPrefilter() {
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+
+        String query = """
+            from test
+            | where knn(dense_vector, [0, 1, 2], 10) or integer > 10
+            """;
+        var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json"));
+
+        var limit = as(plan, LimitExec.class);
+        var exchange = as(limit.child(), ExchangeExec.class);
+        var project = as(exchange.child(), ProjectExec.class);
+        var field = as(project.child(), FieldExtractExec.class);
+        var queryExec = as(field.child(), EsQueryExec.class);
+
+        // The disjunction should not be pushed down to the KNN query
+        KnnVectorQueryBuilder knnQueryBuilder = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 10, null, null, null);
+        QueryBuilder rangeQueryBuilder = wrapWithSingleQuery(
+            query,
+            unscore(rangeQuery("integer").gt(10)),
+            "integer",
+            new Source(2, 44, "integer > 10")
+        );
+
+        var expectedQuery = boolQuery().should(knnQueryBuilder).should(rangeQueryBuilder);
+
+        assertEquals(expectedQuery.toString(), queryExec.query().toString());
+    }
+
+    public void testNotPushDownKnnWithNonPushablePrefilters() {
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+
+        String query = """
+            from test
+            | where ((knn(dense_vector, [0, 1, 2], 10) AND integer > 10) and ((keyword == "test") or length(text) > 10))
+            """;
+        var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json"));
+
+        var limit = as(plan, LimitExec.class);
+        var exchange = as(limit.child(), ExchangeExec.class);
+        var project = as(exchange.child(), ProjectExec.class);
+        var field = as(project.child(), FieldExtractExec.class);
+        var secondLimit = as(field.child(), LimitExec.class);
+        var filter = as(secondLimit.child(), FilterExec.class);
+        var and = as(filter.condition(), And.class);
+        var knn = as(and.left(), Knn.class);
+        assertEquals("(keyword == \"test\") or length(text) > 10", knn.filterExpressions().get(0).toString());
+        assertEquals("integer > 10", knn.filterExpressions().get(1).toString());
+
+        var fieldExtract = as(filter.child(), FieldExtractExec.class);
+        var queryExec = as(fieldExtract.child(), EsQueryExec.class);
+
+        // The query should only contain the pushable condition
+        QueryBuilder integerGtQuery = wrapWithSingleQuery(
+            query,
+            unscore(rangeQuery("integer").gt(10)),
+            "integer",
+            new Source(2, 47, "integer > 10")
+        );
+
+        assertEquals(integerGtQuery.toString(), queryExec.query().toString());
+    }
+
+    public void testPushDownComplexNegationsToKnnPrefilter() {
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+
+        String query = """
+            from test
+            | where ((knn(dense_vector, [0, 1, 2], 10) or NOT integer > 10)
+              and NOT ((keyword == "test") or knn(dense_vector, [4, 5, 6], 10)))
+            """;
+        var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json"));
+
+        var limit = as(plan, LimitExec.class);
+        var exchange = as(limit.child(), ExchangeExec.class);
+        var project = as(exchange.child(), ProjectExec.class);
+        var fieldExtract = as(project.child(), FieldExtractExec.class);
+        var queryExec = as(fieldExtract.child(), EsQueryExec.class);
+
+        QueryBuilder notKeywordQuery = wrapWithSingleQuery(
+            query,
+            unscore(boolQuery().mustNot(unscore(termQuery("keyword", "test")))),
+            "keyword",
+            new Source(3, 12, "keyword == \"test\"")
+        );
+        QueryBuilder notKeywordFilter = wrapWithSingleQuery(
+            query,
+            unscore(boolQuery().mustNot(unscore(termQuery("keyword", "test")))),
+            "keyword",
+            new Source(3, 6, "NOT ((keyword == \"test\") or knn(dense_vector, [4, 5, 6], 10))")
+        );
+
+        QueryBuilder notIntegerGt10 = wrapWithSingleQuery(
+            query,
+            unscore(boolQuery().mustNot(unscore(rangeQuery("integer").gt(10)))),
+            "integer",
+            new Source(2, 46, "NOT integer > 10")
+        );
+
+        KnnVectorQueryBuilder firstKnn = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 10, null, null, null);
+        KnnVectorQueryBuilder secondKnn = new KnnVectorQueryBuilder("dense_vector", new float[] { 4, 5, 6 }, 10, null, null, null);
+
+        firstKnn.addFilterQuery(notKeywordFilter);
+        secondKnn.addFilterQuery(notIntegerGt10);
+
+        // Build the main boolean query structure
+        BoolQueryBuilder expectedQuery = boolQuery().must(notKeywordQuery)  // NOT (keyword == "test")
+            .must(unscore(boolQuery().mustNot(secondKnn)))
+            .must(boolQuery().should(firstKnn).should(notIntegerGt10));
+
+        assertEquals(expectedQuery.toString(), queryExec.query().toString());
+    }
+
+    public void testMultipleKnnQueriesInPrefilters() {
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+
+        String query = """
+            from test
+            | where ((knn(dense_vector, [0, 1, 2], 10) or integer > 10) and ((keyword == "test") or knn(dense_vector, [4, 5, 6], 10)))
+            """;
+        var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json"));
+
+        var limit = as(plan, LimitExec.class);
+        var exchange = as(limit.child(), ExchangeExec.class);
+        var project = as(exchange.child(), ProjectExec.class);
+        var field = as(project.child(), FieldExtractExec.class);
+        var queryExec = as(field.child(), EsQueryExec.class);
+
+        KnnVectorQueryBuilder firstKnnQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 10, null, null, null);
+        // Integer range query (right side of first OR)
+        QueryBuilder integerRangeQuery = wrapWithSingleQuery(
+            query,
+            unscore(rangeQuery("integer").gt(10)),
+            "integer",
+            new Source(2, 46, "integer > 10")
+        );
+
+        // Second KNN query (right side of second OR)
+        KnnVectorQueryBuilder secondKnnQuery = new KnnVectorQueryBuilder("dense_vector", new float[] { 4, 5, 6 }, 10, null, null, null);
+
+        // Keyword term query (left side of second OR)
+        QueryBuilder keywordQuery = wrapWithSingleQuery(
+            query,
+            unscore(termQuery("keyword", "test")),
+            "keyword",
+            new Source(2, 66, "keyword == \"test\"")
+        );
+
+        // First OR (knn1 OR integer > 10)
+        var firstOr = boolQuery().should(firstKnnQuery).should(integerRangeQuery);
+        // Second OR (keyword == "test" OR knn2)
+        var secondOr = boolQuery().should(keywordQuery).should(secondKnnQuery);
+        firstKnnQuery.addFilterQuery(keywordQuery);
+        secondKnnQuery.addFilterQuery(integerRangeQuery);
+
+        // Top-level AND combining both ORs
+        var expectedQuery = boolQuery().must(firstOr).must(secondOr);
+        assertEquals(expectedQuery.toString(), queryExec.query().toString());
+    }
+
     public void testParallelizeTimeSeriesPlan() {
         assumeTrue("requires snapshot builds", Build.current().isSnapshot());
         var query = "TS k8s | STATS max(rate(network.total_bytes_in)) BY bucket(@timestamp, 1h)";
@@ -2234,4 +2539,33 @@ public class LocalPhysicalPlanOptimizerTests extends MapperServiceTestCase {
             return "qstr(\"" + fieldName() + ": " + queryString() + "\")";
         }
     }
+
+    private class KnnFunctionTestCase extends FullTextFunctionTestCase {
+
+        final int k;
+
+        KnnFunctionTestCase() {
+            super(Knn.class, "dense_vector", randomVector());
+            k = randomIntBetween(1, 10);
+        }
+
+        private static Object randomVector() {
+            int numDims = randomIntBetween(10, 20);
+            float[] vector = new float[numDims];
+            for (int i = 0; i < numDims; i++) {
+                vector[i] = randomFloat();
+            }
+            return vector;
+        }
+
+        @Override
+        public QueryBuilder queryBuilder() {
+            return new KnnVectorQueryBuilder(fieldName(), (float[]) queryString(), k, null, null, null);
+        }
+
+        @Override
+        public String esqlQuery() {
+            return "knn(" + fieldName() + ", " + Arrays.toString(((float[]) queryString())) + ", " + k + ")";
+        }
+    }
 }

+ 171 - 0
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java

@@ -74,6 +74,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMin;
 import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvSum;
 import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce;
 import org.elasticsearch.xpack.esql.expression.function.scalar.string.Concat;
+import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
 import org.elasticsearch.xpack.esql.expression.predicate.logical.And;
 import org.elasticsearch.xpack.esql.expression.predicate.logical.Not;
 import org.elasticsearch.xpack.esql.expression.predicate.logical.Or;
@@ -7855,4 +7856,174 @@ public class LogicalPlanOptimizerTests extends AbstractLogicalPlanOptimizerTests
         var topN = as(changePoint.child(), TopN.class);
         var source = as(topN.child(), EsRelation.class);
     }
+
+    public void testPushDownConjunctionsToKnnPrefilter() {
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+
+        var query = """
+            from test
+            | where knn(dense_vector, [0, 1, 2], 10) and integer > 10
+            """;
+        var optimized = planTypes(query);
+
+        var limit = as(optimized, Limit.class);
+        var filter = as(limit.child(), Filter.class);
+        var and = as(filter.condition(), And.class);
+        var knn = as(and.left(), Knn.class);
+        List<Expression> filterExpressions = knn.filterExpressions();
+        assertThat(filterExpressions.size(), equalTo(1));
+        var prefilter = as(filterExpressions.get(0), GreaterThan.class);
+        assertThat(and.right(), equalTo(prefilter));
+        var esRelation = as(filter.child(), EsRelation.class);
+    }
+
+    public void testPushDownMultipleFiltersToKnnPrefilter() {
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+
+        var query = """
+            from test
+            | where knn(dense_vector, [0, 1, 2], 10)
+            | where integer > 10
+            | where keyword == "test"
+            """;
+        var optimized = planTypes(query);
+
+        var limit = as(optimized, Limit.class);
+        var filter = as(limit.child(), Filter.class);
+        var firstAnd = as(filter.condition(), And.class);
+        var knn = as(firstAnd.left(), Knn.class);
+        var prefilterAnd = as(firstAnd.right(), And.class);
+        as(prefilterAnd.left(), GreaterThan.class);
+        as(prefilterAnd.right(), Equals.class);
+        List<Expression> filterExpressions = knn.filterExpressions();
+        assertThat(filterExpressions.size(), equalTo(1));
+        assertThat(prefilterAnd, equalTo(filterExpressions.get(0)));
+    }
+
+    public void testNotPushDownDisjunctionsToKnnPrefilter() {
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+
+        var query = """
+            from test
+            | where knn(dense_vector, [0, 1, 2], 10) or integer > 10
+            """;
+        var optimized = planTypes(query);
+
+        var limit = as(optimized, Limit.class);
+        var filter = as(limit.child(), Filter.class);
+        var or = as(filter.condition(), Or.class);
+        var knn = as(or.left(), Knn.class);
+        List<Expression> filterExpressions = knn.filterExpressions();
+        assertThat(filterExpressions.size(), equalTo(0));
+    }
+
+    public void testPushDownConjunctionsAndNotDisjunctionsToKnnPrefilter() {
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+
+        /*
+            and
+                and
+                    or
+                        knn(dense_vector, [0, 1, 2], 10)
+                        integer > 10
+                    keyword == "test"
+                 or
+                     short < 5
+                     double > 5.0
+         */
+        // Both conjunctions are pushed down to knn prefilters, disjunctions are not
+        var query = """
+            from test
+            | where
+                 ((knn(dense_vector, [0, 1, 2], 10) or integer > 10) and keyword == "test") and ((short < 5) or (double > 5.0))
+            """;
+        var optimized = planTypes(query);
+
+        var limit = as(optimized, Limit.class);
+        var filter = as(limit.child(), Filter.class);
+        var and = as(filter.condition(), And.class);
+        var leftAnd = as(and.left(), And.class);
+        var rightOr = as(and.right(), Or.class);
+        var leftOr = as(leftAnd.left(), Or.class);
+        var knn = as(leftOr.left(), Knn.class);
+        var rightOrPrefilter = as(knn.filterExpressions().get(0), Or.class);
+        assertThat(rightOr, equalTo(rightOrPrefilter));
+        var leftAndPrefilter = as(knn.filterExpressions().get(1), Equals.class);
+        assertThat(leftAnd.right(), equalTo(leftAndPrefilter));
+    }
+
+    public void testMorePushDownConjunctionsAndNotDisjunctionsToKnnPrefilter() {
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+
+        /*
+            or
+                or
+                    and
+                        knn(dense_vector, [0, 1, 2], 10)
+                        integer > 10
+                    keyword == "test"
+                 and
+                     short < 5
+                     double > 5.0
+         */
+        // Just the conjunction is pushed down to knn prefilters, disjunctions are not
+        var query = """
+            from test
+            | where
+                 ((knn(dense_vector, [0, 1, 2], 10) and integer > 10) or keyword == "test") or ((short < 5) and (double > 5.0))
+            """;
+        var optimized = planTypes(query);
+
+        var limit = as(optimized, Limit.class);
+        var filter = as(limit.child(), Filter.class);
+        var or = as(filter.condition(), Or.class);
+        var leftOr = as(or.left(), Or.class);
+        var leftAnd = as(leftOr.left(), And.class);
+        var knn = as(leftAnd.left(), Knn.class);
+        var rightAndPrefilter = as(knn.filterExpressions().get(0), GreaterThan.class);
+        assertThat(leftAnd.right(), equalTo(rightAndPrefilter));
+    }
+
+    public void testMultipleKnnQueriesInPrefilters() {
+        assumeTrue("knn must be enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
+
+        /*
+            and
+                or
+                    knn(dense_vector, [0, 1, 2], 10)
+                    integer > 10
+                or
+                    keyword == "test"
+                    knn(dense_vector, [4, 5, 6], 10)
+         */
+        var query = """
+            from test
+            | where ((knn(dense_vector, [0, 1, 2], 10) or integer > 10) and ((keyword == "test") or knn(dense_vector, [4, 5, 6], 10)))
+            """;
+        var optimized = planTypes(query);
+
+        var limit = as(optimized, Limit.class);
+        var filter = as(limit.child(), Filter.class);
+        var and = as(filter.condition(), And.class);
+
+        // First OR (knn1 OR integer > 10)
+        var firstOr = as(and.left(), Or.class);
+        var firstKnn = as(firstOr.left(), Knn.class);
+        var integerGt = as(firstOr.right(), GreaterThan.class);
+
+        // Second OR (keyword == "test" OR knn2)
+        var secondOr = as(and.right(), Or.class);
+        as(secondOr.left(), Equals.class);
+        var secondKnn = as(secondOr.right(), Knn.class);
+
+        // First KNN should have the second OR as its filter
+        List<Expression> firstKnnFilters = firstKnn.filterExpressions();
+        assertThat(firstKnnFilters.size(), equalTo(1));
+        assertTrue(firstKnnFilters.contains(secondOr.left()));
+
+        // Second KNN should have the first OR as its filter
+        List<Expression> secondKnnFilters = secondKnn.filterExpressions();
+        assertThat(secondKnnFilters.size(), equalTo(1));
+        assertTrue(secondKnnFilters.contains(firstOr.right()));
+    }
 }