Browse Source

Speed up (filtered) KNN queries for flat vector fields (#130251)

For dense vector fields using the `flat` index, we already know a brute-force search will be used—so there’s no need to go through the codec’s approximate KNN logic. This change skips that step and builds the brute-force query directly, making things faster and simpler.

I tested this on a setup with **10 million random vectors**, each with **1596 dimensions** and **17,500 partitions**, using the `random_vector` track.
The results:

### Performance Comparison

| Metric            | Before    | After      | Change    |
| ----------------- | --------- | ---------- | --------- |
| **Throughput**    | 221 ops/s | 2762 ops/s | 🟢 +1149% |
| **Latency (p50)** | 29.2 ms   | 1.6 ms     | 🔻 -94.4% |
| **Latency (p99)** | 81.6 ms   | 3.5 ms     | 🔻 -95.7% |

Filtered KNN queries on flat vectors are now over 10x faster on my laptop!
Jim Ferenczi 3 months ago
parent
commit
2142915fcb

+ 5 - 0
docs/changelog/130251.yaml

@@ -0,0 +1,5 @@
+pr: 130251
+summary: Speed up (filtered) KNN queries for flat vector fields
+area: Vector Search
+type: enhancement
+issues: []

+ 1 - 1
qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java

@@ -310,7 +310,7 @@ class KnnSearcher {
         }
         if (overSamplingFactor > 1f) {
             // oversample the topK results to get more candidates for the final result
-            knnQuery = new RescoreKnnVectorQuery(VECTOR_FIELD, vector, similarityFunction, this.topK, knnQuery);
+            knnQuery = RescoreKnnVectorQuery.fromInnerQuery(VECTOR_FIELD, vector, similarityFunction, this.topK, topK, knnQuery);
         }
         QueryProfiler profiler = new QueryProfiler();
         TopDocs docs = searcher.search(knnQuery, this.topK);

+ 104 - 9
server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

@@ -30,6 +30,8 @@ import org.apache.lucene.index.SegmentReadState;
 import org.apache.lucene.index.SegmentWriteState;
 import org.apache.lucene.index.VectorEncoding;
 import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.BooleanQuery;
 import org.apache.lucene.search.FieldExistsQuery;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.join.BitSetProducer;
@@ -77,6 +79,7 @@ import org.elasticsearch.search.aggregations.support.CoreValuesSourceType;
 import org.elasticsearch.search.lookup.Source;
 import org.elasticsearch.search.vectors.DenseVectorQuery;
 import org.elasticsearch.search.vectors.DiversifyingChildrenIVFKnnFloatVectorQuery;
+import org.elasticsearch.search.vectors.DiversifyingParentBlockQuery;
 import org.elasticsearch.search.vectors.ESDiversifyingChildrenByteKnnVectorQuery;
 import org.elasticsearch.search.vectors.ESDiversifyingChildrenFloatKnnVectorQuery;
 import org.elasticsearch.search.vectors.ESKnnByteVectorQuery;
@@ -1391,6 +1394,18 @@ public class DenseVectorFieldMapper extends FieldMapper {
         public final int hashCode() {
             return Objects.hash(type, doHashCode());
         }
+
+        /**
+         * Indicates whether the underlying vector search is performed using a flat (exhaustive) approach.
+         * <p>
+         * When {@code true}, it means the search does not use any approximate nearest neighbor (ANN)
+         * acceleration structures such as HNSW or IVF. Instead, it performs a brute-force comparison
+         * against all candidate vectors. This information can be used by higher-level components
+         * to decide whether additional acceleration or optimization is necessary.
+         *
+         * @return {@code true} if the vector search is flat (exhaustive), {@code false} if it uses ANN structures
+         */
+        abstract boolean isFlat();
     }
 
     abstract static class QuantizedIndexOptions extends DenseVectorIndexOptions {
@@ -1762,6 +1777,11 @@ public class DenseVectorFieldMapper extends FieldMapper {
             return Objects.hash(confidenceInterval, rescoreVector);
         }
 
+        @Override
+        boolean isFlat() {
+            return true;
+        }
+
         @Override
         public boolean updatableTo(DenseVectorIndexOptions update) {
             return update.type.equals(this.type)
@@ -1810,6 +1830,11 @@ public class DenseVectorFieldMapper extends FieldMapper {
         public int doHashCode() {
             return Objects.hash(type);
         }
+
+        @Override
+        boolean isFlat() {
+            return true;
+        }
     }
 
     public static class Int4HnswIndexOptions extends QuantizedIndexOptions {
@@ -1860,6 +1885,11 @@ public class DenseVectorFieldMapper extends FieldMapper {
             return Objects.hash(m, efConstruction, confidenceInterval, rescoreVector);
         }
 
+        @Override
+        boolean isFlat() {
+            return false;
+        }
+
         @Override
         public String toString() {
             return "{type="
@@ -1931,6 +1961,11 @@ public class DenseVectorFieldMapper extends FieldMapper {
             return Objects.hash(confidenceInterval, rescoreVector);
         }
 
+        @Override
+        boolean isFlat() {
+            return true;
+        }
+
         @Override
         public String toString() {
             return "{type=" + type + ", confidence_interval=" + confidenceInterval + ", rescore_vector=" + rescoreVector + "}";
@@ -1999,6 +2034,11 @@ public class DenseVectorFieldMapper extends FieldMapper {
             return Objects.hash(m, efConstruction, confidenceInterval, rescoreVector);
         }
 
+        @Override
+        boolean isFlat() {
+            return false;
+        }
+
         @Override
         public String toString() {
             return "{type="
@@ -2088,6 +2128,11 @@ public class DenseVectorFieldMapper extends FieldMapper {
             return Objects.hash(m, efConstruction);
         }
 
+        @Override
+        boolean isFlat() {
+            return false;
+        }
+
         @Override
         public String toString() {
             return "{type=" + type + ", m=" + m + ", ef_construction=" + efConstruction + "}";
@@ -2126,6 +2171,11 @@ public class DenseVectorFieldMapper extends FieldMapper {
             return Objects.hash(m, efConstruction, rescoreVector);
         }
 
+        @Override
+        boolean isFlat() {
+            return false;
+        }
+
         @Override
         public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
             builder.startObject();
@@ -2179,6 +2229,11 @@ public class DenseVectorFieldMapper extends FieldMapper {
             return CLASS_NAME_HASH;
         }
 
+        @Override
+        boolean isFlat() {
+            return true;
+        }
+
         @Override
         public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
             builder.startObject();
@@ -2237,6 +2292,11 @@ public class DenseVectorFieldMapper extends FieldMapper {
             return Objects.hash(clusterSize, defaultNProbe, rescoreVector);
         }
 
+        @Override
+        boolean isFlat() {
+            return false;
+        }
+
         @Override
         public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
             builder.startObject();
@@ -2485,9 +2545,21 @@ public class DenseVectorFieldMapper extends FieldMapper {
             KnnSearchStrategy searchStrategy
         ) {
             elementType.checkDimensions(dims, queryVector.length);
-            Query knnQuery = parentFilter != null
-                ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy)
-                : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy);
+            Query knnQuery;
+            if (indexOptions != null && indexOptions.isFlat()) {
+                var exactKnnQuery = parentFilter != null
+                    ? new DiversifyingParentBlockQuery(parentFilter, createExactKnnBitQuery(queryVector))
+                    : createExactKnnBitQuery(queryVector);
+                knnQuery = filter == null
+                    ? exactKnnQuery
+                    : new BooleanQuery.Builder().add(exactKnnQuery, BooleanClause.Occur.SHOULD)
+                        .add(filter, BooleanClause.Occur.FILTER)
+                        .build();
+            } else {
+                knnQuery = parentFilter != null
+                    ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy)
+                    : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy);
+            }
             if (similarityThreshold != null) {
                 knnQuery = new VectorSimilarityQuery(
                     knnQuery,
@@ -2513,9 +2585,22 @@ public class DenseVectorFieldMapper extends FieldMapper {
                 float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
                 elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude);
             }
-            Query knnQuery = parentFilter != null
-                ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy)
-                : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy);
+
+            Query knnQuery;
+            if (indexOptions != null && indexOptions.isFlat()) {
+                var exactKnnQuery = parentFilter != null
+                    ? new DiversifyingParentBlockQuery(parentFilter, createExactKnnByteQuery(queryVector))
+                    : createExactKnnByteQuery(queryVector);
+                knnQuery = filter == null
+                    ? exactKnnQuery
+                    : new BooleanQuery.Builder().add(exactKnnQuery, BooleanClause.Occur.SHOULD)
+                        .add(filter, BooleanClause.Occur.FILTER)
+                        .build();
+            } else {
+                knnQuery = parentFilter != null
+                    ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy)
+                    : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy);
+            }
             if (similarityThreshold != null) {
                 knnQuery = new VectorSimilarityQuery(
                     knnQuery,
@@ -2568,7 +2653,16 @@ public class DenseVectorFieldMapper extends FieldMapper {
                 numCands = Math.max(adjustedK, numCands);
             }
             Query knnQuery;
-            if (indexOptions instanceof BBQIVFIndexOptions bbqIndexOptions) {
+            if (indexOptions != null && indexOptions.isFlat()) {
+                var exactKnnQuery = parentFilter != null
+                    ? new DiversifyingParentBlockQuery(parentFilter, createExactKnnFloatQuery(queryVector))
+                    : createExactKnnFloatQuery(queryVector);
+                knnQuery = filter == null
+                    ? exactKnnQuery
+                    : new BooleanQuery.Builder().add(exactKnnQuery, BooleanClause.Occur.SHOULD)
+                        .add(filter, BooleanClause.Occur.FILTER)
+                        .build();
+            } else if (indexOptions instanceof BBQIVFIndexOptions bbqIndexOptions) {
                 knnQuery = parentFilter != null
                     ? new DiversifyingChildrenIVFKnnFloatVectorQuery(
                         name(),
@@ -2594,11 +2688,12 @@ public class DenseVectorFieldMapper extends FieldMapper {
                     : new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter, knnSearchStrategy);
             }
             if (rescore) {
-                knnQuery = new RescoreKnnVectorQuery(
+                knnQuery = RescoreKnnVectorQuery.fromInnerQuery(
                     name(),
                     queryVector,
                     similarity.vectorSimilarityFunction(indexVersionCreated, ElementType.FLOAT),
                     k,
+                    adjustedK,
                     knnQuery
                 );
             }
@@ -2624,7 +2719,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
             return elementType;
         }
 
-        public IndexOptions getIndexOptions() {
+        public DenseVectorIndexOptions getIndexOptions() {
             return indexOptions;
         }
 

+ 1 - 0
server/src/main/java/org/elasticsearch/search/vectors/DenseVectorQuery.java

@@ -207,4 +207,5 @@ public abstract class DenseVectorQuery extends Query {
             return iterator.docID();
         }
     }
+
 }

+ 195 - 0
server/src/main/java/org/elasticsearch/search/vectors/DiversifyingParentBlockQuery.java

@@ -0,0 +1,195 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.search.vectors;
+
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.Explanation;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.QueryVisitor;
+import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.ScorerSupplier;
+import org.apache.lucene.search.Weight;
+import org.apache.lucene.search.join.BitSetProducer;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * A Lucene query that selects the highest-scoring child document for each parent block.
+ * <p>
+ * Children are scored using the {@code innerQuery}, and for each parent (as defined by the
+ * {@code parentFilter}), the single best-scoring child is returned.
+ */
+public class DiversifyingParentBlockQuery extends Query {
+    private final BitSetProducer parentFilter;
+    private final Query innerQuery;
+
+    public DiversifyingParentBlockQuery(BitSetProducer parentFilter, Query innerQuery) {
+        this.parentFilter = Objects.requireNonNull(parentFilter);
+        this.innerQuery = Objects.requireNonNull(innerQuery);
+    }
+
+    @Override
+    public Query rewrite(IndexSearcher indexSearcher) throws IOException {
+        Query rewritten = innerQuery.rewrite(indexSearcher);
+        if (rewritten != innerQuery) {
+            return new DiversifyingParentBlockQuery(parentFilter, rewritten);
+        }
+        return this;
+    }
+
+    @Override
+    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
+        Weight innerWeight = innerQuery.createWeight(searcher, scoreMode, boost);
+        return new DiversifyingParentBlockWeight(this, innerWeight, parentFilter);
+    }
+
+    @Override
+    public String toString(String field) {
+        return "DiversifyingBlockQuery(inner=" + innerQuery.toString(field) + ")";
+    }
+
+    @Override
+    public void visit(QueryVisitor visitor) {
+        innerQuery.visit(visitor.getSubVisitor(BooleanClause.Occur.MUST, this));
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        DiversifyingParentBlockQuery that = (DiversifyingParentBlockQuery) o;
+        return Objects.equals(innerQuery, that.innerQuery) && parentFilter == that.parentFilter;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(innerQuery, parentFilter);
+    }
+
+    private static class DiversifyingParentBlockWeight extends Weight {
+        private final Weight innerWeight;
+        private final BitSetProducer parentFilter;
+
+        DiversifyingParentBlockWeight(Query query, Weight innerWeight, BitSetProducer parentFilter) {
+            super(query);
+            this.innerWeight = innerWeight;
+            this.parentFilter = parentFilter;
+        }
+
+        @Override
+        public Explanation explain(LeafReaderContext context, int doc) throws IOException {
+            return innerWeight.explain(context, doc);
+        }
+
+        @Override
+        public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
+            var innerSupplier = innerWeight.scorerSupplier(context);
+            var parentBits = parentFilter.getBitSet(context);
+            if (parentBits == null || innerSupplier == null) {
+                return null;
+            }
+
+            return new ScorerSupplier() {
+                @Override
+                public Scorer get(long leadCost) throws IOException {
+                    var innerScorer = innerSupplier.get(leadCost);
+                    var innerIterator = innerScorer.iterator();
+                    return new Scorer() {
+                        int currentDoc = -1;
+                        float currentScore = Float.NaN;
+
+                        @Override
+                        public int docID() {
+                            return currentDoc;
+                        }
+
+                        @Override
+                        public DocIdSetIterator iterator() {
+                            return new DocIdSetIterator() {
+                                boolean exhausted = false;
+
+                                @Override
+                                public int docID() {
+                                    return currentDoc;
+                                }
+
+                                @Override
+                                public int nextDoc() throws IOException {
+                                    return advance(currentDoc + 1);
+                                }
+
+                                @Override
+                                public int advance(int target) throws IOException {
+                                    if (exhausted) {
+                                        return currentDoc = NO_MORE_DOCS;
+                                    }
+                                    if (currentDoc == -1 || innerIterator.docID() < target) {
+                                        if (innerIterator.advance(target) == NO_MORE_DOCS) {
+                                            exhausted = true;
+                                            return currentDoc = NO_MORE_DOCS;
+                                        }
+                                    }
+
+                                    int bestChild = innerIterator.docID();
+                                    float bestScore = innerScorer.score();
+                                    int parent = parentBits.nextSetBit(bestChild);
+
+                                    int innerDoc;
+                                    while ((innerDoc = innerIterator.nextDoc()) < parent) {
+                                        float score = innerScorer.score();
+                                        if (score > bestScore) {
+                                            bestChild = innerIterator.docID();
+                                            bestScore = score;
+                                        }
+                                    }
+                                    if (innerDoc == NO_MORE_DOCS) {
+                                        exhausted = true;
+                                    }
+                                    currentScore = bestScore;
+                                    return currentDoc = bestChild;
+                                }
+
+                                @Override
+                                public long cost() {
+                                    return innerIterator.cost();
+                                }
+                            };
+                        }
+
+                        @Override
+                        public float score() throws IOException {
+                            return currentScore;
+                        }
+
+                        @Override
+                        public float getMaxScore(int upTo) throws IOException {
+                            return innerScorer.getMaxScore(upTo);
+                        }
+                    };
+                }
+
+                @Override
+                public long cost() {
+                    return innerSupplier.cost();
+                }
+            };
+        }
+
+        @Override
+        public boolean isCacheable(LeafReaderContext ctx) {
+            return false;
+        }
+    }
+}

+ 128 - 22
server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java

@@ -12,8 +12,9 @@ package org.elasticsearch.search.vectors;
 import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.queries.function.FunctionScoreQuery;
 import org.apache.lucene.search.BooleanClause;
-import org.apache.lucene.search.DoubleValuesSource;
 import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.KnnByteVectorQuery;
+import org.apache.lucene.search.KnnFloatVectorQuery;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.QueryVisitor;
 import org.apache.lucene.search.TopDocs;
@@ -25,17 +26,28 @@ import java.util.Arrays;
 import java.util.Objects;
 
 /**
- * Wraps an internal query to rescore the results using a similarity function over the original, non-quantized vectors of a vector field
+ * A Lucene {@link Query} that applies vector-based rescoring to an inner query's results.
+ * <p>
+ * Depending on the nature of the {@code innerQuery}, this class dynamically selects between two rescoring strategies:
+ * <ul>
+ *   <li><b>Inline rescoring</b>:
+ *   Used when the inner query is already a top-N vector query with {@code rescoreK} results.
+ *   The vector similarity is applied inline using a {@link FunctionScoreQuery} without an additional
+ *   filtering pass.</li>
+ *   <li><b>Late rescoring</b>: Used when the inner query is not a top-N vector query or does not return
+ *   {@code rescoreK} results. The top {@code rescoreK} documents are first collected, and then rescoring is applied
+ *   separately to select the final top {@code k}.</li>
+ * </ul>
  */
-public class RescoreKnnVectorQuery extends Query implements QueryProfilerProvider {
-    private final String fieldName;
-    private final float[] floatTarget;
-    private final VectorSimilarityFunction vectorSimilarityFunction;
-    private final int k;
-    private final Query innerQuery;
-    private long vectorOperations = 0;
-
-    public RescoreKnnVectorQuery(
+public abstract class RescoreKnnVectorQuery extends Query implements QueryProfilerProvider {
+    protected final String fieldName;
+    protected final float[] floatTarget;
+    protected final VectorSimilarityFunction vectorSimilarityFunction;
+    protected final int k;
+    protected final Query innerQuery;
+    protected long vectorOperations = 0;
+
+    private RescoreKnnVectorQuery(
         String fieldName,
         float[] floatTarget,
         VectorSimilarityFunction vectorSimilarityFunction,
@@ -49,16 +61,31 @@ public class RescoreKnnVectorQuery extends Query implements QueryProfilerProvide
         this.innerQuery = innerQuery;
     }
 
-    @Override
-    public Query rewrite(IndexSearcher searcher) throws IOException {
-        DoubleValuesSource valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction);
-        FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(innerQuery, valueSource);
-        Query query = searcher.rewrite(functionScoreQuery);
-
-        // Retrieve top k documents from the rescored query
-        TopDocs topDocs = searcher.search(query, k);
-        vectorOperations = topDocs.totalHits.value();
-        return new KnnScoreDocQuery(topDocs.scoreDocs, searcher.getIndexReader());
+    /**
+     * Selects and returns the appropriate {@link RescoreKnnVectorQuery} strategy based on the nature of the {@code innerQuery}.
+     *
+     * @param fieldName                 the name of the field containing the vector
+     * @param floatTarget              the target vector to compare against
+     * @param vectorSimilarityFunction the similarity function to apply
+     * @param k                        the number of top documents to return after rescoring
+     * @param rescoreK                 the number of top documents to consider for rescoring
+     * @param innerQuery               the original Lucene query to rescore
+     */
+    public static RescoreKnnVectorQuery fromInnerQuery(
+        String fieldName,
+        float[] floatTarget,
+        VectorSimilarityFunction vectorSimilarityFunction,
+        int k,
+        int rescoreK,
+        Query innerQuery
+    ) {
+        if ((innerQuery instanceof KnnFloatVectorQuery fQuery && fQuery.getK() == rescoreK)
+            || (innerQuery instanceof KnnByteVectorQuery bQuery && bQuery.getK() == rescoreK)
+            || (innerQuery instanceof AbstractIVFKnnVectorQuery ivfQuery && ivfQuery.k == rescoreK)) {
+            // Queries that return only the top `k` results and do not require reduction before re-scoring.
+            return new InlineRescoreQuery(fieldName, floatTarget, vectorSimilarityFunction, k, innerQuery);
+        }
+        return new LateRescoreQuery(fieldName, floatTarget, vectorSimilarityFunction, k, rescoreK, innerQuery);
     }
 
     public Query innerQuery() {
@@ -102,7 +129,8 @@ public class RescoreKnnVectorQuery extends Query implements QueryProfilerProvide
 
     @Override
     public String toString(String field) {
-        return "KnnRescoreVectorQuery{"
+        return getClass().getSimpleName()
+            + "{"
             + "fieldName='"
             + fieldName
             + '\''
@@ -117,4 +145,82 @@ public class RescoreKnnVectorQuery extends Query implements QueryProfilerProvide
             + innerQuery
             + '}';
     }
+
+    private static class InlineRescoreQuery extends RescoreKnnVectorQuery {
+        private InlineRescoreQuery(
+            String fieldName,
+            float[] floatTarget,
+            VectorSimilarityFunction vectorSimilarityFunction,
+            int k,
+            Query innerQuery
+        ) {
+            super(fieldName, floatTarget, vectorSimilarityFunction, k, innerQuery);
+        }
+
+        @Override
+        public Query rewrite(IndexSearcher searcher) throws IOException {
+            var valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction);
+            var functionScoreQuery = new FunctionScoreQuery(innerQuery, valueSource);
+            // Retrieve top k documents from the function score query
+            var topDocs = searcher.search(functionScoreQuery, k);
+            vectorOperations = topDocs.totalHits.value();
+            return new KnnScoreDocQuery(topDocs.scoreDocs, searcher.getIndexReader());
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            return super.equals(o);
+        }
+
+        @Override
+        public int hashCode() {
+            return super.hashCode();
+        }
+    }
+
+    private static class LateRescoreQuery extends RescoreKnnVectorQuery {
+        final int rescoreK;
+
+        private LateRescoreQuery(
+            String fieldName,
+            float[] floatTarget,
+            VectorSimilarityFunction vectorSimilarityFunction,
+            int k,
+            int rescoreK,
+            Query innerQuery
+        ) {
+            super(fieldName, floatTarget, vectorSimilarityFunction, k, innerQuery);
+            this.rescoreK = rescoreK;
+        }
+
+        @Override
+        public Query rewrite(IndexSearcher searcher) throws IOException {
+            final TopDocs topDocs;
+            // Retrieve top `rescoreK` documents from the inner query
+            topDocs = searcher.search(innerQuery, rescoreK);
+            vectorOperations = topDocs.totalHits.value();
+
+            // Retrieve top `k` documents from the top `rescoreK` query
+            var topDocsQuery = new KnnScoreDocQuery(topDocs.scoreDocs, searcher.getIndexReader());
+            var valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction);
+            var rescoreQuery = new FunctionScoreQuery(topDocsQuery, valueSource);
+            var rescoreTopDocs = searcher.search(rescoreQuery.rewrite(searcher), k);
+            return new KnnScoreDocQuery(rescoreTopDocs.scoreDocs, searcher.getIndexReader());
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            var that = (RescoreKnnVectorQuery.LateRescoreQuery) o;
+            return super.equals(o) && that.rescoreK == rescoreK;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(super.hashCode(), rescoreK);
+        }
+    }
 }

+ 2 - 8
server/src/test/java/org/elasticsearch/index/mapper/DynamicMappingTests.java

@@ -992,16 +992,10 @@ public class DynamicMappingTests extends MapperServiceTestCase {
         assertThat(((FieldMapper) update.getRoot().getMapper("mapsToFloatTooBig")).fieldType().typeName(), equalTo("float"));
         assertThat(((FieldMapper) update.getRoot().getMapper("mapsToInt8HnswDenseVector")).fieldType().typeName(), equalTo("dense_vector"));
         DenseVectorFieldMapper int8DVFieldMapper = ((DenseVectorFieldMapper) update.getRoot().getMapper("mapsToInt8HnswDenseVector"));
-        assertThat(
-            ((DenseVectorFieldMapper.DenseVectorIndexOptions) int8DVFieldMapper.fieldType().getIndexOptions()).getType().getName(),
-            equalTo("int8_hnsw")
-        );
+        assertThat(int8DVFieldMapper.fieldType().getIndexOptions().getType().getName(), equalTo("int8_hnsw"));
         assertThat(((FieldMapper) update.getRoot().getMapper("mapsToBBQHnswDenseVector")).fieldType().typeName(), equalTo("dense_vector"));
         DenseVectorFieldMapper bbqDVFieldMapper = ((DenseVectorFieldMapper) update.getRoot().getMapper("mapsToBBQHnswDenseVector"));
-        assertThat(
-            ((DenseVectorFieldMapper.DenseVectorIndexOptions) bbqDVFieldMapper.fieldType().getIndexOptions()).getType().getName(),
-            equalTo("bbq_hnsw")
-        );
+        assertThat(bbqDVFieldMapper.fieldType().getIndexOptions().getType().getName(), equalTo("bbq_hnsw"));
     }
 
     public void testDefaultDenseVectorMappingsObject() throws IOException {

+ 45 - 12
server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java

@@ -25,6 +25,7 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.DenseVector
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.VectorSimilarity;
 import org.elasticsearch.search.DocValueFormat;
 import org.elasticsearch.search.vectors.DenseVectorQuery;
+import org.elasticsearch.search.vectors.DiversifyingParentBlockQuery;
 import org.elasticsearch.search.vectors.ESKnnByteVectorQuery;
 import org.elasticsearch.search.vectors.ESKnnFloatVectorQuery;
 import org.elasticsearch.search.vectors.RescoreKnnVectorQuery;
@@ -238,7 +239,11 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
             if (query instanceof RescoreKnnVectorQuery rescoreKnnVectorQuery) {
                 query = rescoreKnnVectorQuery.innerQuery();
             }
-            assertThat(query, instanceOf(DiversifyingChildrenFloatKnnVectorQuery.class));
+            if (field.getIndexOptions().isFlat()) {
+                assertThat(query, instanceOf(DiversifyingParentBlockQuery.class));
+            } else {
+                assertThat(query, instanceOf(DiversifyingChildrenFloatKnnVectorQuery.class));
+            }
         }
         {
             DenseVectorFieldType field = new DenseVectorFieldType(
@@ -269,7 +274,11 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
                 producer,
                 randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
             );
-            assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class));
+            if (field.getIndexOptions().isFlat()) {
+                assertThat(query, instanceOf(DiversifyingParentBlockQuery.class));
+            } else {
+                assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class));
+            }
 
             vectorData = new VectorData(floatQueryVector, null);
             query = field.createKnnQuery(
@@ -282,7 +291,11 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
                 producer,
                 randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
             );
-            assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class));
+            if (field.getIndexOptions().isFlat()) {
+                assertThat(query, instanceOf(DiversifyingParentBlockQuery.class));
+            } else {
+                assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class));
+            }
         }
     }
 
@@ -445,7 +458,11 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
             if (query instanceof RescoreKnnVectorQuery rescoreKnnVectorQuery) {
                 query = rescoreKnnVectorQuery.innerQuery();
             }
-            assertThat(query, instanceOf(KnnFloatVectorQuery.class));
+            if (fieldWith4096dims.getIndexOptions().isFlat()) {
+                assertThat(query, instanceOf(DenseVectorQuery.Floats.class));
+            } else {
+                assertThat(query, instanceOf(KnnFloatVectorQuery.class));
+            }
         }
 
         {   // byte type with 4096 dims
@@ -475,7 +492,11 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
                 null,
                 randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
             );
-            assertThat(query, instanceOf(KnnByteVectorQuery.class));
+            if (fieldWith4096dims.getIndexOptions().isFlat()) {
+                assertThat(query, instanceOf(DenseVectorQuery.Bytes.class));
+            } else {
+                assertThat(query, instanceOf(KnnByteVectorQuery.class));
+            }
         }
     }
 
@@ -574,13 +595,21 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
         );
 
         if (elementType == BYTE) {
-            ESKnnByteVectorQuery esKnnQuery = (ESKnnByteVectorQuery) knnQuery;
-            assertThat(esKnnQuery.getK(), is(100));
-            assertThat(esKnnQuery.kParam(), is(10));
+            if (nonQuantizedField.getIndexOptions().isFlat()) {
+                assertThat(knnQuery, instanceOf(DenseVectorQuery.Bytes.class));
+            } else {
+                ESKnnByteVectorQuery esKnnQuery = (ESKnnByteVectorQuery) knnQuery;
+                assertThat(esKnnQuery.getK(), is(100));
+                assertThat(esKnnQuery.kParam(), is(10));
+            }
         } else {
-            ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) knnQuery;
-            assertThat(esKnnQuery.getK(), is(100));
-            assertThat(esKnnQuery.kParam(), is(10));
+            if (nonQuantizedField.getIndexOptions().isFlat()) {
+                assertThat(knnQuery, instanceOf(DenseVectorQuery.Floats.class));
+            } else {
+                ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) knnQuery;
+                assertThat(esKnnQuery.getK(), is(100));
+                assertThat(esKnnQuery.kParam(), is(10));
+            }
         }
     }
 
@@ -628,7 +657,11 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
             null,
             randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
         );
-        assertTrue(query instanceof ESKnnFloatVectorQuery);
+        if (fieldType.getIndexOptions().isFlat()) {
+            assertThat(query, instanceOf(DenseVectorQuery.Floats.class));
+        } else {
+            assertThat(query, instanceOf(ESKnnFloatVectorQuery.class));
+        }
 
         // verify we can override a `0` to a positive number
         fieldType = new DenseVectorFieldType(

+ 34 - 3
server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java

@@ -51,6 +51,7 @@ import java.util.stream.Stream;
 import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.DEFAULT_OVERSAMPLE;
 import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.OVERSAMPLE_LIMIT;
 import static org.elasticsearch.search.SearchService.DEFAULT_SIZE;
+import static org.hamcrest.Matchers.anyOf;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
@@ -203,8 +204,14 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
             }
         }
         switch (elementType()) {
-            case FLOAT -> assertTrue(query instanceof ESKnnFloatVectorQuery);
-            case BYTE -> assertTrue(query instanceof ESKnnByteVectorQuery);
+            case FLOAT -> assertThat(
+                query,
+                anyOf(instanceOf(ESKnnFloatVectorQuery.class), instanceOf(DenseVectorQuery.Floats.class), instanceOf(BooleanQuery.class))
+            );
+            case BYTE -> assertThat(
+                query,
+                anyOf(instanceOf(ESKnnByteVectorQuery.class), instanceOf(DenseVectorQuery.Bytes.class), instanceOf(BooleanQuery.class))
+            );
         }
 
         BooleanQuery.Builder builder = new BooleanQuery.Builder();
@@ -244,10 +251,34 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
                 expectedStrategy
             );
         };
+
+        Query bruteForceVectorQueryBuilt = switch (elementType()) {
+            case BIT, BYTE -> {
+                if (filterQuery != null) {
+                    yield new BooleanQuery.Builder().add(
+                        new DenseVectorQuery.Bytes(queryBuilder.queryVector().asByteVector(), VECTOR_FIELD),
+                        BooleanClause.Occur.SHOULD
+                    ).add(filterQuery, BooleanClause.Occur.FILTER).build();
+                } else {
+                    yield new DenseVectorQuery.Bytes(queryBuilder.queryVector().asByteVector(), VECTOR_FIELD);
+                }
+            }
+            case FLOAT -> {
+                if (filterQuery != null) {
+                    yield new BooleanQuery.Builder().add(
+                        new DenseVectorQuery.Floats(queryBuilder.queryVector().asFloatVector(), VECTOR_FIELD),
+                        BooleanClause.Occur.SHOULD
+                    ).add(filterQuery, BooleanClause.Occur.FILTER).build();
+                } else {
+                    yield new DenseVectorQuery.Floats(queryBuilder.queryVector().asFloatVector(), VECTOR_FIELD);
+                }
+            }
+        };
+
         if (query instanceof VectorSimilarityQuery vectorSimilarityQuery) {
             query = vectorSimilarityQuery.getInnerKnnQuery();
         }
-        assertEquals(query, knnVectorQueryBuilt);
+        assertThat(query, anyOf(equalTo(knnVectorQueryBuilt), equalTo(bruteForceVectorQueryBuilt)));
     }
 
     public void testWrongDimension() {

+ 162 - 0
server/src/test/java/org/elasticsearch/search/vectors/DiversifyingParentBlockQueryTests.java

@@ -0,0 +1,162 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.search.vectors;
+
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.join.ScoreMode;
+import org.apache.lucene.search.join.ToParentBlockJoinQuery;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.index.mapper.MapperServiceTestCase;
+import org.elasticsearch.index.mapper.ParsedDocument;
+import org.elasticsearch.index.mapper.SourceToParse;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentType;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Locale;
+import java.util.Set;
+import java.util.TreeMap;
+
+import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.instanceOf;
+
+public class DiversifyingParentBlockQueryTests extends MapperServiceTestCase {
+    private static String getMapping(int dim) {
+        return String.format(Locale.ROOT, """
+                {
+                  "_doc": {
+                    "properties": {
+                      "id": {
+                        "type": "keyword",
+                        "store": true
+                      },
+                      "nested": {
+                        "type": "nested",
+                          "properties": {
+                            "emb": {
+                              "type": "dense_vector",
+                              "dims": %d,
+                              "similarity": "l2_norm",
+                              "index_options": {
+                                "type": "flat"
+                              }
+                            }
+                          }
+                        }
+                      }
+                    }
+                  }
+                }
+            """, dim);
+    }
+
+    public void testRandom() throws IOException {
+        int dims = randomIntBetween(3, 10);
+        var mapperService = createMapperService(getMapping(dims));
+        var fieldType = (DenseVectorFieldMapper.DenseVectorFieldType) mapperService.fieldType("nested.emb");
+        var nestedParent = mapperService.mappingLookup().nestedLookup().getNestedMappers().get("nested");
+
+        int numQueries = randomIntBetween(1, 3);
+        float[][] queries = new float[numQueries][];
+        List<TreeMap<Float, String>> expectedTopDocs = new ArrayList<>();
+        for (int i = 0; i < numQueries; i++) {
+            queries[i] = randomVector(dims);
+            expectedTopDocs.add(new TreeMap<>((o1, o2) -> -Float.compare(o1, o2)));
+        }
+
+        withLuceneIndex(mapperService, iw -> {
+            int numDocs = randomIntBetween(10, 50);
+            for (int i = 0; i < numDocs; i++) {
+                int numVectors = randomIntBetween(0, 5);
+                float[][] vectors = new float[numVectors][];
+                for (int j = 0; j < numVectors; j++) {
+                    vectors[j] = randomVector(dims);
+                }
+
+                for (int k = 0; k < numQueries; k++) {
+                    float maxScore = Float.MIN_VALUE;
+                    for (int j = 0; j < numVectors; j++) {
+                        float score = EUCLIDEAN.compare(vectors[j], queries[k]);
+                        maxScore = Math.max(score, maxScore);
+                    }
+                    expectedTopDocs.get(k).put(maxScore, Integer.toString(i));
+                }
+
+                SourceToParse source = randomSource(Integer.toString(i), vectors);
+                ParsedDocument doc = mapperService.documentMapper().parse(source);
+                iw.addDocuments(doc.docs());
+
+                if (randomBoolean()) {
+                    int numEmpty = randomIntBetween(1, 3);
+                    for (int l = 0; l < numEmpty; l++) {
+                        source = randomSource(randomAlphaOfLengthBetween(5, 10), new float[0][]);
+                        doc = mapperService.documentMapper().parse(source);
+                        iw.addDocuments(doc.docs());
+                    }
+                }
+            }
+        }, ir -> {
+            var storedFields = ir.storedFields();
+            var searcher = new IndexSearcher(wrapInMockESDirectoryReader(ir));
+            var context = createSearchExecutionContext(mapperService);
+            var bitSetproducer = context.bitsetFilter(nestedParent.parentTypeFilter());
+            for (int i = 0; i < numQueries; i++) {
+                var knnQuery = fieldType.createKnnQuery(
+                    VectorData.fromFloats(queries[i]),
+                    10,
+                    10,
+                    null,
+                    null,
+                    null,
+                    bitSetproducer,
+                    DenseVectorFieldMapper.FilterHeuristic.ACORN
+                );
+                assertThat(knnQuery, instanceOf(DiversifyingParentBlockQuery.class));
+                var nestedQuery = new ToParentBlockJoinQuery(knnQuery, bitSetproducer, ScoreMode.Total);
+                var topDocs = searcher.search(nestedQuery, 10);
+                for (var doc : topDocs.scoreDocs) {
+                    var entry = expectedTopDocs.get(i).pollFirstEntry();
+                    assertNotNull(entry);
+                    assertThat(doc.score, equalTo(entry.getKey()));
+                    var storedDoc = storedFields.document(doc.doc, Set.of("id"));
+                    assertThat(storedDoc.getField("id").binaryValue().utf8ToString(), equalTo(entry.getValue()));
+                }
+            }
+        });
+    }
+
+    private SourceToParse randomSource(String id, float[][] vectors) throws IOException {
+        try (var builder = XContentBuilder.builder(XContentType.JSON.xContent())) {
+            builder.startObject();
+            builder.field("id", id);
+            builder.startArray("nested");
+            for (int i = 0; i < vectors.length; i++) {
+                builder.startObject();
+                builder.field("emb", vectors[i]);
+                builder.endObject();
+            }
+            builder.endArray();
+            builder.endObject();
+            return new SourceToParse(id, BytesReference.bytes(builder), XContentType.JSON);
+        }
+    }
+
+    private float[] randomVector(int dim) {
+        float[] vector = new float[dim];
+        for (int i = 0; i < vector.length; i++) {
+            vector[i] = randomFloat();
+        }
+        return vector;
+    }
+}

+ 52 - 42
server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java

@@ -18,7 +18,10 @@ import org.apache.lucene.index.IndexWriter;
 import org.apache.lucene.index.IndexWriterConfig;
 import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.queries.function.FunctionScoreQuery;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.BooleanQuery;
 import org.apache.lucene.search.DoubleValuesSource;
+import org.apache.lucene.search.FieldExistsQuery;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.KnnFloatVectorQuery;
 import org.apache.lucene.search.MatchAllDocsQuery;
@@ -41,6 +44,8 @@ import org.elasticsearch.test.ESTestCase;
 
 import java.io.IOException;
 import java.io.UnsupportedEncodingException;
+import java.util.ArrayList;
+import java.util.List;
 
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.greaterThan;
@@ -54,53 +59,57 @@ public class RescoreKnnVectorQueryTests extends ESTestCase {
         int numDims = randomIntBetween(5, 100);
         int k = randomIntBetween(1, numDocs - 1);
 
+        var queryVector = randomVector(numDims);
+        List<Query> innerQueries = new ArrayList<>();
+        innerQueries.add(new KnnFloatVectorQuery(FIELD_NAME, randomVector(numDims), (int) (k * randomFloatBetween(1.0f, 10.0f, true))));
+        innerQueries.add(
+            new BooleanQuery.Builder().add(new DenseVectorQuery.Floats(queryVector, FIELD_NAME), BooleanClause.Occur.SHOULD)
+                .add(new FieldExistsQuery(FIELD_NAME), BooleanClause.Occur.FILTER)
+                .build()
+        );
+        innerQueries.add(new MatchAllDocsQuery());
+
         try (Directory d = newDirectory()) {
             addRandomDocuments(numDocs, d, numDims);
 
             try (IndexReader reader = DirectoryReader.open(d)) {
 
-                // Use a RescoreKnnVectorQuery with a match all query, to ensure we get scoring of 1 from the inner query
-                // and thus we're rescoring the top k docs.
-                float[] queryVector = randomVector(numDims);
-                Query innerQuery;
-                if (randomBoolean()) {
-                    innerQuery = new KnnFloatVectorQuery(FIELD_NAME, queryVector, (int) (k * randomFloatBetween(1.0f, 10.0f, true)));
-                } else {
-                    innerQuery = new MatchAllDocsQuery();
-                }
-                RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery(
-                    FIELD_NAME,
-                    queryVector,
-                    VectorSimilarityFunction.COSINE,
-                    k,
-                    innerQuery
-                );
-
-                IndexSearcher searcher = newSearcher(reader, true, false);
-                TopDocs rescoredDocs = searcher.search(rescoreKnnVectorQuery, numDocs);
-                assertThat(rescoredDocs.scoreDocs.length, equalTo(k));
-
-                // Get real scores
-                DoubleValuesSource valueSource = new VectorSimilarityFloatValueSource(
-                    FIELD_NAME,
-                    queryVector,
-                    VectorSimilarityFunction.COSINE
-                );
-                FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(new MatchAllDocsQuery(), valueSource);
-                TopDocs realScoreTopDocs = searcher.search(functionScoreQuery, numDocs);
-
-                int i = 0;
-                ScoreDoc[] realScoreDocs = realScoreTopDocs.scoreDocs;
-                for (ScoreDoc rescoreDoc : rescoredDocs.scoreDocs) {
-                    // There are docs that won't be found in the rescored search, but every doc found must be in the same order
-                    // and have the same score
-                    while (i < realScoreDocs.length && realScoreDocs[i].doc != rescoreDoc.doc) {
-                        i++;
-                    }
-                    if (i >= realScoreDocs.length) {
-                        fail("Rescored doc not found in real score docs");
+                for (var innerQuery : innerQueries) {
+                    RescoreKnnVectorQuery rescoreKnnVectorQuery = RescoreKnnVectorQuery.fromInnerQuery(
+                        FIELD_NAME,
+                        queryVector,
+                        VectorSimilarityFunction.COSINE,
+                        k,
+                        k,
+                        innerQuery
+                    );
+
+                    IndexSearcher searcher = newSearcher(reader, true, false);
+                    TopDocs rescoredDocs = searcher.search(rescoreKnnVectorQuery, numDocs);
+                    assertThat(rescoredDocs.scoreDocs.length, equalTo(k));
+
+                    // Get real scores
+                    DoubleValuesSource valueSource = new VectorSimilarityFloatValueSource(
+                        FIELD_NAME,
+                        queryVector,
+                        VectorSimilarityFunction.COSINE
+                    );
+                    FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(new MatchAllDocsQuery(), valueSource);
+                    TopDocs realScoreTopDocs = searcher.search(functionScoreQuery, numDocs);
+
+                    int i = 0;
+                    ScoreDoc[] realScoreDocs = realScoreTopDocs.scoreDocs;
+                    for (ScoreDoc rescoreDoc : rescoredDocs.scoreDocs) {
+                        // There are docs that won't be found in the rescored search, but every doc found must be in the same order
+                        // and have the same score
+                        while (i < realScoreDocs.length && realScoreDocs[i].doc != rescoreDoc.doc) {
+                            i++;
+                        }
+                        if (i >= realScoreDocs.length) {
+                            fail("Rescored doc not found in real score docs");
+                        }
+                        assertThat("Real score is not the same as rescored score", rescoreDoc.score, equalTo(realScoreDocs[i].score));
                     }
-                    assertThat("Real score is not the same as rescored score", rescoreDoc.score, equalTo(realScoreDocs[i].score));
                 }
             }
         }
@@ -124,11 +133,12 @@ public class RescoreKnnVectorQueryTests extends ESTestCase {
     }
 
     private void checkProfiling(int k, int numDocs, float[] queryVector, IndexReader reader, Query innerQuery) throws IOException {
-        RescoreKnnVectorQuery rescoreKnnVectorQuery = new RescoreKnnVectorQuery(
+        var rescoreKnnVectorQuery = RescoreKnnVectorQuery.fromInnerQuery(
             FIELD_NAME,
             queryVector,
             VectorSimilarityFunction.COSINE,
             k,
+            k,
             innerQuery
         );
         IndexSearcher searcher = newSearcher(reader, true, false);