Przeglądaj źródła

Wrap ES KNN queries with PatienceKNN query (#127223)

Tommaso Teofili 3 miesięcy temu
rodzic
commit
9edfa6642a

+ 5 - 0
docs/changelog/127223.yaml

@@ -0,0 +1,5 @@
+pr: 127223
+summary: Wrap ES KNN queries with PatienceKNN query
+area: Vector Search
+type: feature
+issues: []

+ 3 - 0
docs/reference/elasticsearch/index-settings/index-modules.md

@@ -259,3 +259,6 @@ $$$index-esql-stored-fields-sequential-proportion$$$
 
 `index.esql.stored_fields_sequential_proportion`
 :   Tuning parameter for deciding when {{esql}} will load [Stored fields](/reference/elasticsearch/rest-apis/retrieve-selected-fields.md#stored-fields) using a strategy tuned for loading dense sequence of documents. Allows values between 0.0 and 1.0 and defaults to 0.2. Indices with documents smaller than 10kb may see speed improvements loading `text` fields by setting this lower.
+
+$$$index-dense-vector-hnsw-early-termination$$$ `index.dense_vector.hnsw_early_termination`
+:   Whether to apply _patience_ based early termination strategy to knn queries over HNSW graphs (see [paper](https://cs.uwaterloo.ca/~jimmylin/publications/Teofili_Lin_ECIR2025.pdf)). This is only applicable to `dense_vector` fields with `hnsw`, `int8_hnsw`, `int4_hnsw` and `bbq_hnsw` index types. Defaults to `false`.

+ 12 - 2
qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java

@@ -48,7 +48,8 @@ record CmdLineArgs(
     VectorSimilarityFunction vectorSpace,
     int quantizeBits,
     VectorEncoding vectorEncoding,
-    int dimensions
+    int dimensions,
+    boolean earlyTermination
 ) implements ToXContentObject {
 
     static final ParseField DOC_VECTORS_FIELD = new ParseField("doc_vectors");
@@ -71,6 +72,7 @@ record CmdLineArgs(
     static final ParseField QUANTIZE_BITS_FIELD = new ParseField("quantize_bits");
     static final ParseField VECTOR_ENCODING_FIELD = new ParseField("vector_encoding");
     static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions");
+    static final ParseField EARLY_TERMINATION_FIELD = new ParseField("early_termination");
 
     static CmdLineArgs fromXContent(XContentParser parser) throws IOException {
         Builder builder = PARSER.apply(parser, null);
@@ -100,6 +102,7 @@ record CmdLineArgs(
         PARSER.declareInt(Builder::setQuantizeBits, QUANTIZE_BITS_FIELD);
         PARSER.declareString(Builder::setVectorEncoding, VECTOR_ENCODING_FIELD);
         PARSER.declareInt(Builder::setDimensions, DIMENSIONS_FIELD);
+        PARSER.declareBoolean(Builder::setEarlyTermination, EARLY_TERMINATION_FIELD);
     }
 
     @Override
@@ -158,6 +161,7 @@ record CmdLineArgs(
         private int quantizeBits = 8;
         private VectorEncoding vectorEncoding = VectorEncoding.FLOAT32;
         private int dimensions;
+        private boolean earlyTermination;
 
         public Builder setDocVectors(String docVectors) {
             this.docVectors = PathUtils.get(docVectors);
@@ -259,6 +263,11 @@ record CmdLineArgs(
             return this;
         }
 
+        public Builder setEarlyTermination(Boolean patience) {
+            this.earlyTermination = patience;
+            return this;
+        }
+
         public CmdLineArgs build() {
             if (docVectors == null) {
                 throw new IllegalArgumentException("Document vectors path must be provided");
@@ -288,7 +297,8 @@ record CmdLineArgs(
                 vectorSpace,
                 quantizeBits,
                 vectorEncoding,
-                dimensions
+                dimensions,
+                earlyTermination
             );
         }
     }

+ 1 - 1
qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java

@@ -211,7 +211,7 @@ public class KnnIndexTester {
                 for (int i = 0; i < results.length; i++) {
                     int nProbe = nProbes[i];
                     KnnSearcher knnSearcher = new KnnSearcher(indexPath, cmdLineArgs, nProbe);
-                    knnSearcher.runSearch(results[i]);
+                    knnSearcher.runSearch(results[i], cmdLineArgs.earlyTermination());
                 }
             }
             formattedResults.results.addAll(List.of(results));

+ 22 - 10
qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java

@@ -33,6 +33,9 @@ import org.apache.lucene.queries.function.valuesource.ConstKnnFloatValueSource;
 import org.apache.lucene.queries.function.valuesource.FloatKnnVectorFieldSource;
 import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction;
 import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.KnnByteVectorQuery;
+import org.apache.lucene.search.KnnFloatVectorQuery;
+import org.apache.lucene.search.PatienceKnnVectorQuery;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.TopDocs;
@@ -114,7 +117,7 @@ class KnnSearcher {
         this.searchThreads = cmdLineArgs.searchThreads();
     }
 
-    void runSearch(KnnIndexTester.Results finalResults) throws IOException {
+    void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) throws IOException {
         TopDocs[] results = new TopDocs[numQueryVectors];
         int[][] resultIds = new int[numQueryVectors][];
         long elapsed, totalCpuTimeMS, totalVisited = 0;
@@ -153,10 +156,10 @@ class KnnSearcher {
                     for (int i = 0; i < numQueryVectors; i++) {
                         if (vectorEncoding.equals(VectorEncoding.BYTE)) {
                             targetReader.next(targetBytes);
-                            doVectorQuery(targetBytes, searcher);
+                            doVectorQuery(targetBytes, searcher, earlyTermination);
                         } else {
                             targetReader.next(target);
-                            doVectorQuery(target, searcher);
+                            doVectorQuery(target, searcher, earlyTermination);
                         }
                     }
                     targetReader.reset();
@@ -165,10 +168,10 @@ class KnnSearcher {
                     for (int i = 0; i < numQueryVectors; i++) {
                         if (vectorEncoding.equals(VectorEncoding.BYTE)) {
                             targetReader.next(targetBytes);
-                            results[i] = doVectorQuery(targetBytes, searcher);
+                            results[i] = doVectorQuery(targetBytes, searcher, earlyTermination);
                         } else {
                             targetReader.next(target);
-                            results[i] = doVectorQuery(target, searcher);
+                            results[i] = doVectorQuery(target, searcher, earlyTermination);
                         }
                     }
                     KnnIndexTester.ThreadDetails endThreadDetails = new KnnIndexTester.ThreadDetails();
@@ -264,7 +267,7 @@ class KnnSearcher {
         return true;
     }
 
-    TopDocs doVectorQuery(byte[] vector, IndexSearcher searcher) throws IOException {
+    TopDocs doVectorQuery(byte[] vector, IndexSearcher searcher, boolean earlyTermination) throws IOException {
         Query knnQuery;
         if (overSamplingFactor > 1f) {
             throw new IllegalArgumentException("oversampling factor > 1 is not supported for byte vectors");
@@ -280,6 +283,9 @@ class KnnSearcher {
                 null,
                 DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy()
             );
+            if (indexType == KnnIndexTester.IndexType.HNSW && earlyTermination) {
+                knnQuery = PatienceKnnVectorQuery.fromByteQuery((KnnByteVectorQuery) knnQuery);
+            }
         }
         QueryProfiler profiler = new QueryProfiler();
         TopDocs docs = searcher.search(knnQuery, this.topK);
@@ -288,7 +294,7 @@ class KnnSearcher {
         return new TopDocs(new TotalHits(profiler.getVectorOpsCount(), docs.totalHits.relation()), docs.scoreDocs);
     }
 
-    TopDocs doVectorQuery(float[] vector, IndexSearcher searcher) throws IOException {
+    TopDocs doVectorQuery(float[] vector, IndexSearcher searcher, boolean earlyTermination) throws IOException {
         Query knnQuery;
         int topK = this.topK;
         if (overSamplingFactor > 1f) {
@@ -307,6 +313,9 @@ class KnnSearcher {
                 null,
                 DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy()
             );
+            if (indexType == KnnIndexTester.IndexType.HNSW && earlyTermination) {
+                knnQuery = PatienceKnnVectorQuery.fromFloatQuery((KnnFloatVectorQuery) knnQuery);
+            }
         }
         if (overSamplingFactor > 1f) {
             // oversample the topK results to get more candidates for the final result
@@ -314,9 +323,12 @@ class KnnSearcher {
         }
         QueryProfiler profiler = new QueryProfiler();
         TopDocs docs = searcher.search(knnQuery, this.topK);
-        QueryProfilerProvider queryProfilerProvider = (QueryProfilerProvider) knnQuery;
-        queryProfilerProvider.profile(profiler);
-        return new TopDocs(new TotalHits(profiler.getVectorOpsCount(), docs.totalHits.relation()), docs.scoreDocs);
+        if (knnQuery instanceof QueryProfilerProvider queryProfilerProvider) {
+            queryProfilerProvider.profile(profiler);
+            return new TopDocs(new TotalHits(profiler.getVectorOpsCount(), docs.totalHits.relation()), docs.scoreDocs);
+        } else {
+            return docs;
+        }
     }
 
     private static float checkResults(int[][] results, int[][] nn, int topK) {

+ 51 - 0
server/src/internalClusterTest/java/org/elasticsearch/search/query/VectorIT.java

@@ -127,4 +127,55 @@ public class VectorIT extends ESIntegTestCase {
         });
     }
 
+    public void testHnswEarlyTerminationQuery() {
+        float[] vector = new float[16];
+        randomVector(vector, 25);
+        int upperLimit = 35;
+        var query = new KnnSearchBuilder(VECTOR_FIELD, vector, 1, 1, null, null);
+        assertResponse(client().prepareSearch(INDEX_NAME).setKnnSearch(List.of(query)).setSize(1).setProfile(true), response -> {
+            assertNotEquals(0, response.getHits().getHits().length);
+            var profileResults = response.getProfileResults();
+            long vectorOpsSum = profileResults.values()
+                .stream()
+                .mapToLong(
+                    pr -> pr.getQueryPhase()
+                        .getSearchProfileDfsPhaseResult()
+                        .getQueryProfileShardResult()
+                        .stream()
+                        .mapToLong(qpr -> qpr.getVectorOperationsCount().longValue())
+                        .sum()
+                )
+                .sum();
+            client().admin()
+                .indices()
+                .prepareUpdateSettings(INDEX_NAME)
+                .setSettings(Settings.builder().put(DenseVectorFieldMapper.HNSW_EARLY_TERMINATION.getKey(), true))
+                .get();
+            assertResponse(
+                client().prepareSearch(INDEX_NAME).setKnnSearch(List.of(query)).setSize(1).setProfile(true),
+                earlyTerminationResponse -> {
+                    assertNotEquals(0, earlyTerminationResponse.getHits().getHits().length);
+                    var earlyTerminationResults = earlyTerminationResponse.getProfileResults();
+                    long earlyTerminationVectorOpsSum = earlyTerminationResults.values()
+                        .stream()
+                        .mapToLong(
+                            pr -> pr.getQueryPhase()
+                                .getSearchProfileDfsPhaseResult()
+                                .getQueryProfileShardResult()
+                                .stream()
+                                .mapToLong(qpr -> qpr.getVectorOperationsCount().longValue())
+                                .sum()
+                        )
+                        .sum();
+                    assertTrue(
+                        "earlyTerminationVectorOps [" + earlyTerminationVectorOpsSum + "] is not lt vectorOps [" + vectorOpsSum + "]",
+                        earlyTerminationVectorOpsSum < vectorOpsSum
+                            // if both switch to brute-force due to excessive exploration, they will both equal to upperLimit
+                            || (earlyTerminationVectorOpsSum == vectorOpsSum && vectorOpsSum == upperLimit + 1)
+                    );
+                }
+            );
+        });
+    }
+
 }

+ 1 - 0
server/src/main/java/org/elasticsearch/common/settings/IndexScopedSettings.java

@@ -159,6 +159,7 @@ public final class IndexScopedSettings extends AbstractScopedSettings {
                 IndexSettings.INDEX_TRANSLOG_RETENTION_SIZE_SETTING,
                 IndexSettings.INDEX_SEARCH_IDLE_AFTER,
                 DenseVectorFieldMapper.HNSW_FILTER_HEURISTIC,
+                DenseVectorFieldMapper.HNSW_EARLY_TERMINATION,
                 IndexFieldDataService.INDEX_FIELDDATA_CACHE_KEY,
                 IndexSettings.IGNORE_ABOVE_SETTING,
                 FieldMapper.IGNORE_MALFORMED_SETTING,

+ 11 - 0
server/src/main/java/org/elasticsearch/index/IndexSettings.java

@@ -916,6 +916,7 @@ public final class IndexSettings {
     private volatile int maxNgramDiff;
     private volatile int maxShingleDiff;
     private volatile DenseVectorFieldMapper.FilterHeuristic hnswFilterHeuristic;
+    private volatile boolean earlyTermination;
     private volatile TimeValue searchIdleAfter;
     private volatile int maxAnalyzedOffset;
     private volatile boolean weightMatchesEnabled;
@@ -1113,6 +1114,7 @@ public final class IndexSettings {
         skipIgnoredSourceWrite = scopedSettings.get(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_WRITE_SETTING);
         skipIgnoredSourceRead = scopedSettings.get(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_READ_SETTING);
         hnswFilterHeuristic = scopedSettings.get(DenseVectorFieldMapper.HNSW_FILTER_HEURISTIC);
+        earlyTermination = scopedSettings.get(DenseVectorFieldMapper.HNSW_EARLY_TERMINATION);
         indexMappingSourceMode = scopedSettings.get(INDEX_MAPPER_SOURCE_MODE_SETTING);
         recoverySourceEnabled = RecoverySettings.INDICES_RECOVERY_SOURCE_ENABLED_SETTING.get(nodeSettings);
         recoverySourceSyntheticEnabled = DiscoveryNode.isStateless(nodeSettings) == false
@@ -1227,6 +1229,7 @@ public final class IndexSettings {
         );
         scopedSettings.addSettingsUpdateConsumer(IgnoredSourceFieldMapper.SKIP_IGNORED_SOURCE_READ_SETTING, this::setSkipIgnoredSourceRead);
         scopedSettings.addSettingsUpdateConsumer(DenseVectorFieldMapper.HNSW_FILTER_HEURISTIC, this::setHnswFilterHeuristic);
+        scopedSettings.addSettingsUpdateConsumer(DenseVectorFieldMapper.HNSW_EARLY_TERMINATION, this::setHnswEarlyTermination);
     }
 
     private void setSearchIdleAfter(TimeValue searchIdleAfter) {
@@ -1858,6 +1861,14 @@ public final class IndexSettings {
         this.hnswFilterHeuristic = heuristic;
     }
 
+    public boolean getHnswEarlyTermination() {
+        return this.earlyTermination;
+    }
+
+    private void setHnswEarlyTermination(boolean earlyTermination) {
+        this.earlyTermination = earlyTermination;
+    }
+
     public SeqNoFieldMapper.SeqNoIndexOptions seqNoIndexOptions() {
         return seqNoIndexOptions;
     }

+ 54 - 10
server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

@@ -33,6 +33,9 @@ import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.search.BooleanClause;
 import org.apache.lucene.search.BooleanQuery;
 import org.apache.lucene.search.FieldExistsQuery;
+import org.apache.lucene.search.KnnByteVectorQuery;
+import org.apache.lucene.search.KnnFloatVectorQuery;
+import org.apache.lucene.search.PatienceKnnVectorQuery;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.join.BitSetProducer;
 import org.apache.lucene.search.knn.KnnSearchStrategy;
@@ -123,6 +126,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
     private static final float EPS = 1e-3f;
     public static final int BBQ_MIN_DIMS = 64;
 
+    private static final boolean DEFAULT_HNSW_EARLY_TERMINATION = false;
     public static final FeatureFlag IVF_FORMAT = new FeatureFlag("ivf_format");
 
     public static boolean isNotUnitVector(float magnitude) {
@@ -174,6 +178,14 @@ public class DenseVectorFieldMapper extends FieldMapper {
         Setting.Property.Dynamic
     );
 
+    public static final Setting<Boolean> HNSW_EARLY_TERMINATION = Setting.boolSetting(
+        "index.dense_vector.hnsw_enable_early_termination",
+        DEFAULT_HNSW_EARLY_TERMINATION,
+        Setting.Property.IndexScope,
+        Setting.Property.ServerlessPublic,
+        Setting.Property.Dynamic
+    );
+
     private static boolean hasRescoreIndexVersion(IndexVersion version) {
         return version.onOrAfter(IndexVersions.ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS)
             || version.between(IndexVersions.ADD_RESCORE_PARAMS_TO_QUANTIZED_VECTORS_BACKPORT_8_X, IndexVersions.UPGRADE_TO_LUCENE_10_0_0);
@@ -212,7 +224,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
     public static final int MAX_DIMS_COUNT_BIT = 4096 * Byte.SIZE; // maximum allowed number of dimensions
 
     public static final short MIN_DIMS_FOR_DYNAMIC_FLOAT_MAPPING = 128; // minimum number of dims for floats to be dynamically mapped to
-                                                                        // vector
+    // vector
     public static final int MAGNITUDE_BYTES = 4;
     public static final int OVERSAMPLE_LIMIT = 10_000; // Max oversample allowed
     public static final float DEFAULT_OVERSAMPLE = 3.0F; // Default oversample value
@@ -1429,6 +1441,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
                 if (efConstructionNode == null) {
                     efConstructionNode = Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
                 }
+
                 int m = XContentMapValues.nodeIntegerValue(mNode);
                 int efConstruction = XContentMapValues.nodeIntegerValue(efConstructionNode);
                 MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap);
@@ -2487,7 +2500,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
             Query filter,
             Float similarityThreshold,
             BitSetProducer parentFilter,
-            DenseVectorFieldMapper.FilterHeuristic heuristic
+            FilterHeuristic heuristic,
+            boolean hnswEarlyTermination
         ) {
             if (isIndexed() == false) {
                 throw new IllegalArgumentException(
@@ -2503,7 +2517,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
                     filter,
                     similarityThreshold,
                     parentFilter,
-                    knnSearchStrategy
+                    knnSearchStrategy,
+                    hnswEarlyTermination
                 );
                 case FLOAT -> createKnnFloatQuery(
                     queryVector.asFloatVector(),
@@ -2513,7 +2528,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
                     filter,
                     similarityThreshold,
                     parentFilter,
-                    knnSearchStrategy
+                    knnSearchStrategy,
+                    hnswEarlyTermination
                 );
                 case BIT -> createKnnBitQuery(
                     queryVector.asByteVector(),
@@ -2522,7 +2538,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
                     filter,
                     similarityThreshold,
                     parentFilter,
-                    knnSearchStrategy
+                    knnSearchStrategy,
+                    hnswEarlyTermination
                 );
             };
         }
@@ -2542,7 +2559,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
             Query filter,
             Float similarityThreshold,
             BitSetProducer parentFilter,
-            KnnSearchStrategy searchStrategy
+            KnnSearchStrategy searchStrategy,
+            boolean hnswEarlyTermination
         ) {
             elementType.checkDimensions(dims, queryVector.length);
             Query knnQuery;
@@ -2559,6 +2577,9 @@ public class DenseVectorFieldMapper extends FieldMapper {
                 knnQuery = parentFilter != null
                     ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy)
                     : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy);
+                if (hnswEarlyTermination) {
+                    knnQuery = maybeWrapPatience(knnQuery);
+                }
             }
             if (similarityThreshold != null) {
                 knnQuery = new VectorSimilarityQuery(
@@ -2577,7 +2598,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
             Query filter,
             Float similarityThreshold,
             BitSetProducer parentFilter,
-            KnnSearchStrategy searchStrategy
+            KnnSearchStrategy searchStrategy,
+            boolean hnswEarlyTermination
         ) {
             elementType.checkDimensions(dims, queryVector.length);
 
@@ -2585,7 +2607,6 @@ public class DenseVectorFieldMapper extends FieldMapper {
                 float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
                 elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude);
             }
-
             Query knnQuery;
             if (indexOptions != null && indexOptions.isFlat()) {
                 var exactKnnQuery = parentFilter != null
@@ -2600,6 +2621,9 @@ public class DenseVectorFieldMapper extends FieldMapper {
                 knnQuery = parentFilter != null
                     ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy)
                     : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy);
+                if (hnswEarlyTermination) {
+                    knnQuery = maybeWrapPatience(knnQuery);
+                }
             }
             if (similarityThreshold != null) {
                 knnQuery = new VectorSimilarityQuery(
@@ -2611,6 +2635,23 @@ public class DenseVectorFieldMapper extends FieldMapper {
             return knnQuery;
         }
 
+        private Query maybeWrapPatience(Query knnQuery) {
+            Query finalQuery = knnQuery;
+            if (knnQuery instanceof KnnByteVectorQuery knnByteVectorQuery && canApplyPatienceQuery()) {
+                finalQuery = PatienceKnnVectorQuery.fromByteQuery(knnByteVectorQuery);
+            } else if (knnQuery instanceof KnnFloatVectorQuery knnFloatVectorQuery && canApplyPatienceQuery()) {
+                finalQuery = PatienceKnnVectorQuery.fromFloatQuery(knnFloatVectorQuery);
+            }
+            return finalQuery;
+        }
+
+        private boolean canApplyPatienceQuery() {
+            return indexOptions instanceof HnswIndexOptions
+                || indexOptions instanceof Int8HnswIndexOptions
+                || indexOptions instanceof Int4HnswIndexOptions
+                || indexOptions instanceof BBQHnswIndexOptions;
+        }
+
         private Query createKnnFloatQuery(
             float[] queryVector,
             int k,
@@ -2619,7 +2660,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
             Query filter,
             Float similarityThreshold,
             BitSetProducer parentFilter,
-            KnnSearchStrategy knnSearchStrategy
+            KnnSearchStrategy knnSearchStrategy,
+            boolean hnswEarlyTermination
         ) {
             elementType.checkDimensions(dims, queryVector.length);
             elementType.checkVectorBounds(queryVector);
@@ -2686,6 +2728,9 @@ public class DenseVectorFieldMapper extends FieldMapper {
                         knnSearchStrategy
                     )
                     : new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter, knnSearchStrategy);
+                if (hnswEarlyTermination) {
+                    knnQuery = maybeWrapPatience(knnQuery);
+                }
             }
             if (rescore) {
                 knnQuery = RescoreKnnVectorQuery.fromInnerQuery(
@@ -2808,7 +2853,6 @@ public class DenseVectorFieldMapper extends FieldMapper {
         }
         if (fieldType().dims == null) {
             int dims = fieldType().elementType.parseDimensionCount(context);
-            ;
             final boolean defaultInt8Hnsw = indexCreatedVersion.onOrAfter(IndexVersions.DEFAULT_DENSE_VECTOR_TO_INT8_HNSW);
             final boolean defaultBBQ8Hnsw = indexCreatedVersion.onOrAfter(IndexVersions.DEFAULT_DENSE_VECTOR_TO_BBQ_HNSW);
             DenseVectorIndexOptions denseVectorIndexOptions = fieldType().indexOptions;

+ 3 - 1
server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java

@@ -553,6 +553,7 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
             }
         }
         DenseVectorFieldMapper.FilterHeuristic heuristic = context.getIndexSettings().getHnswFilterHeuristic();
+        boolean hnswEarlyTermination = context.getIndexSettings().getHnswEarlyTermination();
         return vectorFieldType.createKnnQuery(
             queryVector,
             k,
@@ -561,7 +562,8 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
             filterQuery,
             vectorSimilarity,
             parentBitSet,
-            heuristic
+            heuristic,
+            hnswEarlyTermination
         );
     }
 

+ 20 - 10
server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java

@@ -2429,7 +2429,8 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
                 null,
                 null,
                 null,
-                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
+                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
+                randomBoolean()
             )
         );
         assertThat(
@@ -2447,7 +2448,8 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
                 null,
                 null,
                 null,
-                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
+                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
+                randomBoolean()
             )
         );
         assertThat(
@@ -2465,7 +2467,8 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
                 null,
                 null,
                 null,
-                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
+                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
+                randomBoolean()
             )
         );
         assertThat(
@@ -2483,7 +2486,8 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
                 null,
                 null,
                 null,
-                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
+                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
+                randomBoolean()
             )
         );
         assertThat(
@@ -2501,7 +2505,8 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
                 null,
                 null,
                 null,
-                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
+                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
+                randomBoolean()
             )
         );
         assertThat(e.getMessage(), containsString("element_type [byte] vectors do not support NaN values but found [NaN] at dim [0];"));
@@ -2516,7 +2521,8 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
                 null,
                 null,
                 null,
-                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
+                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
+                randomBoolean()
             )
         );
         assertThat(
@@ -2534,7 +2540,8 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
                 null,
                 null,
                 null,
-                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
+                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
+                randomBoolean()
             )
         );
         assertThat(
@@ -2569,7 +2576,8 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
                 null,
                 null,
                 null,
-                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
+                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
+                randomBoolean()
             )
         );
         assertThat(e.getMessage(), containsString("element_type [float] vectors do not support NaN values but found [NaN] at dim [0];"));
@@ -2584,7 +2592,8 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
                 null,
                 null,
                 null,
-                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
+                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
+                randomBoolean()
             )
         );
         assertThat(
@@ -2602,7 +2611,8 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
                 null,
                 null,
                 null,
-                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
+                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
+                randomBoolean()
             )
         );
         assertThat(

+ 92 - 56
server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java

@@ -9,8 +9,8 @@
 
 package org.elasticsearch.index.mapper.vectors;
 
-import org.apache.lucene.search.KnnByteVectorQuery;
 import org.apache.lucene.search.KnnFloatVectorQuery;
+import org.apache.lucene.search.PatienceKnnVectorQuery;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.join.BitSetProducer;
 import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery;
@@ -234,7 +234,8 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
                 null,
                 null,
                 producer,
-                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
+                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
+                randomBoolean()
             );
             if (query instanceof RescoreKnnVectorQuery rescoreKnnVectorQuery) {
                 query = rescoreKnnVectorQuery.innerQuery();
@@ -242,7 +243,7 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
             if (field.getIndexOptions().isFlat()) {
                 assertThat(query, instanceOf(DiversifyingParentBlockQuery.class));
             } else {
-                assertThat(query, instanceOf(DiversifyingChildrenFloatKnnVectorQuery.class));
+                assertTrue(query instanceof DiversifyingChildrenFloatKnnVectorQuery || query instanceof PatienceKnnVectorQuery);
             }
         }
         {
@@ -272,12 +273,13 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
                 null,
                 null,
                 producer,
-                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
+                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
+                randomBoolean()
             );
             if (field.getIndexOptions().isFlat()) {
                 assertThat(query, instanceOf(DiversifyingParentBlockQuery.class));
             } else {
-                assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class));
+                assertTrue(query instanceof DiversifyingChildrenByteKnnVectorQuery || query instanceof PatienceKnnVectorQuery);
             }
 
             vectorData = new VectorData(floatQueryVector, null);
@@ -289,12 +291,13 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
                 null,
                 null,
                 producer,
-                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
+                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
+                randomBoolean()
             );
             if (field.getIndexOptions().isFlat()) {
                 assertThat(query, instanceOf(DiversifyingParentBlockQuery.class));
             } else {
-                assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class));
+                assertTrue(query instanceof DiversifyingChildrenByteKnnVectorQuery || query instanceof PatienceKnnVectorQuery);
             }
         }
     }
@@ -366,7 +369,8 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
                 null,
                 null,
                 null,
-                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
+                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
+                randomBoolean()
             )
         );
         assertThat(e.getMessage(), containsString("to perform knn search on field [f], its mapping must have [index] set to [true]"));
@@ -396,7 +400,8 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
                 null,
                 null,
                 null,
-                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
+                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
+                randomBoolean()
             )
         );
         assertThat(e.getMessage(), containsString("The [dot_product] similarity can only be used with unit-length vectors."));
@@ -422,7 +427,8 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
                 null,
                 null,
                 null,
-                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
+                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
+                randomBoolean()
             )
         );
         assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude."));
@@ -453,7 +459,8 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
                 null,
                 null,
                 null,
-                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
+                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
+                randomBoolean()
             );
             if (query instanceof RescoreKnnVectorQuery rescoreKnnVectorQuery) {
                 query = rescoreKnnVectorQuery.innerQuery();
@@ -461,7 +468,7 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
             if (fieldWith4096dims.getIndexOptions().isFlat()) {
                 assertThat(query, instanceOf(DenseVectorQuery.Floats.class));
             } else {
-                assertThat(query, instanceOf(KnnFloatVectorQuery.class));
+                assertTrue(query instanceof KnnFloatVectorQuery || query instanceof PatienceKnnVectorQuery);
             }
         }
 
@@ -490,12 +497,13 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
                 null,
                 null,
                 null,
-                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
+                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
+                randomBoolean()
             );
             if (fieldWith4096dims.getIndexOptions().isFlat()) {
                 assertThat(query, instanceOf(DenseVectorQuery.Bytes.class));
             } else {
-                assertThat(query, instanceOf(KnnByteVectorQuery.class));
+                assertTrue(query instanceof ESKnnByteVectorQuery || query instanceof PatienceKnnVectorQuery);
             }
         }
     }
@@ -522,7 +530,8 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
                 null,
                 null,
                 null,
-                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
+                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
+                randomBoolean()
             )
         );
         assertThat(e.getMessage(), containsString("to perform knn search on field [f], its mapping must have [index] set to [true]"));
@@ -548,7 +557,8 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
                 null,
                 null,
                 null,
-                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
+                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
+                randomBoolean()
             )
         );
         assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude."));
@@ -563,7 +573,8 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
                 null,
                 null,
                 null,
-                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
+                randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
+                randomBoolean()
             )
         );
         assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude."));
@@ -591,24 +602,33 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
             null,
             null,
             null,
-            randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
+            randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
+            randomBoolean()
         );
 
         if (elementType == BYTE) {
             if (nonQuantizedField.getIndexOptions().isFlat()) {
                 assertThat(knnQuery, instanceOf(DenseVectorQuery.Bytes.class));
             } else {
-                ESKnnByteVectorQuery esKnnQuery = (ESKnnByteVectorQuery) knnQuery;
-                assertThat(esKnnQuery.getK(), is(100));
-                assertThat(esKnnQuery.kParam(), is(10));
+                if (knnQuery instanceof PatienceKnnVectorQuery patienceKnnVectorQuery) {
+                    assertThat(patienceKnnVectorQuery.getK(), is(100));
+                } else {
+                    ESKnnByteVectorQuery knnByteVectorQuery = (ESKnnByteVectorQuery) knnQuery;
+                    assertThat(knnByteVectorQuery.getK(), is(100));
+                    assertThat(knnByteVectorQuery.kParam(), is(10));
+                }
             }
         } else {
             if (nonQuantizedField.getIndexOptions().isFlat()) {
                 assertThat(knnQuery, instanceOf(DenseVectorQuery.Floats.class));
             } else {
-                ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) knnQuery;
-                assertThat(esKnnQuery.getK(), is(100));
-                assertThat(esKnnQuery.kParam(), is(10));
+                if (knnQuery instanceof PatienceKnnVectorQuery patienceKnnVectorQuery) {
+                    assertThat(patienceKnnVectorQuery.getK(), is(100));
+                } else {
+                    ESKnnFloatVectorQuery knnFloatVectorQuery = (ESKnnFloatVectorQuery) knnQuery;
+                    assertThat(knnFloatVectorQuery.getK(), is(100));
+                    assertThat(knnFloatVectorQuery.kParam(), is(10));
+                }
             }
         }
     }
@@ -655,12 +675,13 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
             null,
             null,
             null,
-            randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
+            randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
+            randomBoolean()
         );
         if (fieldType.getIndexOptions().isFlat()) {
             assertThat(query, instanceOf(DenseVectorQuery.Floats.class));
         } else {
-            assertThat(query, instanceOf(ESKnnFloatVectorQuery.class));
+            assertTrue(query instanceof ESKnnFloatVectorQuery || query instanceof PatienceKnnVectorQuery);
         }
 
         // verify we can override a `0` to a positive number
@@ -683,20 +704,23 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
             null,
             null,
             null,
-            randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
+            randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
+            randomBoolean()
         );
         assertTrue(query instanceof RescoreKnnVectorQuery);
-        assertThat(((RescoreKnnVectorQuery) query).k(), equalTo(10));
-        ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) ((RescoreKnnVectorQuery) query).innerQuery();
-        assertThat(esKnnQuery.kParam(), equalTo(20));
-
+        RescoreKnnVectorQuery rescoreKnnVectorQuery = (RescoreKnnVectorQuery) query;
+        assertThat(rescoreKnnVectorQuery.k(), equalTo(10));
+        Query innerQuery = rescoreKnnVectorQuery.innerQuery();
+        if (innerQuery instanceof ESKnnFloatVectorQuery esKnnFloatVectorQuery) {
+            assertThat(esKnnFloatVectorQuery.kParam(), equalTo(20));
+        }
     }
 
     public void testFilterSearchThreshold() {
         List<Tuple<DenseVectorFieldMapper.ElementType, Function<Query, KnnSearchStrategy>>> cases = List.of(
-            Tuple.tuple(FLOAT, q -> ((ESKnnFloatVectorQuery) q).getStrategy()),
-            Tuple.tuple(BYTE, q -> ((ESKnnByteVectorQuery) q).getStrategy()),
-            Tuple.tuple(BIT, q -> ((ESKnnByteVectorQuery) q).getStrategy())
+            Tuple.tuple(FLOAT, q -> q instanceof PatienceKnnVectorQuery ? null : ((ESKnnFloatVectorQuery) q).getStrategy()),
+            Tuple.tuple(BYTE, q -> q instanceof PatienceKnnVectorQuery ? null : ((ESKnnByteVectorQuery) q).getStrategy()),
+            Tuple.tuple(BIT, q -> q instanceof PatienceKnnVectorQuery ? null : ((ESKnnByteVectorQuery) q).getStrategy())
         );
         for (var tuple : cases) {
             DenseVectorFieldType fieldType = new DenseVectorFieldType(
@@ -720,25 +744,31 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
                 null,
                 null,
                 null,
-                DenseVectorFieldMapper.FilterHeuristic.FANOUT
+                DenseVectorFieldMapper.FilterHeuristic.FANOUT,
+                randomBoolean()
             );
             KnnSearchStrategy strategy = tuple.v2().apply(query);
-            assertTrue(strategy instanceof KnnSearchStrategy.Hnsw);
-            assertThat(((KnnSearchStrategy.Hnsw) strategy).filteredSearchThreshold(), equalTo(0));
-
-            query = fieldType.createKnnQuery(
-                VectorData.fromFloats(new float[] { 1, 4, 10 }),
-                10,
-                100,
-                0f,
-                null,
-                null,
-                null,
-                DenseVectorFieldMapper.FilterHeuristic.ACORN
-            );
-            strategy = tuple.v2().apply(query);
-            assertTrue(strategy instanceof KnnSearchStrategy.Hnsw);
-            assertThat(((KnnSearchStrategy.Hnsw) strategy).filteredSearchThreshold(), equalTo(60));
+            if (strategy != null) {
+                assertTrue(strategy instanceof KnnSearchStrategy.Hnsw);
+                assertThat(((KnnSearchStrategy.Hnsw) strategy).filteredSearchThreshold(), equalTo(0));
+
+                query = fieldType.createKnnQuery(
+                    VectorData.fromFloats(new float[] { 1, 4, 10 }),
+                    10,
+                    100,
+                    0f,
+                    null,
+                    null,
+                    null,
+                    DenseVectorFieldMapper.FilterHeuristic.ACORN,
+                    randomBoolean()
+                );
+                strategy = tuple.v2().apply(query);
+                if (strategy != null) {
+                    assertThat(strategy, instanceOf(KnnSearchStrategy.Hnsw.class));
+                    assertThat(((KnnSearchStrategy.Hnsw) strategy).filteredSearchThreshold(), equalTo(60));
+                }
+            }
         }
     }
 
@@ -759,12 +789,18 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
             null,
             null,
             null,
-            randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
+            randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
+            randomBoolean()
         );
         RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery) query;
-        ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) rescoreQuery.innerQuery();
-        assertThat("Unexpected total results", rescoreQuery.k(), equalTo(expectedResults));
-        assertThat("Unexpected k parameter", esKnnQuery.kParam(), equalTo(expectedK));
-        assertThat("Unexpected candidates", esKnnQuery.getK(), equalTo(expectedCandidates));
+        Query innerQuery = rescoreQuery.innerQuery();
+        if (innerQuery instanceof PatienceKnnVectorQuery patienceKnnVectorQuery) {
+            assertThat("Unexpected candidates", patienceKnnVectorQuery.getK(), equalTo(expectedCandidates));
+        } else {
+            ESKnnFloatVectorQuery knnQuery = (ESKnnFloatVectorQuery) innerQuery;
+            assertThat("Unexpected total results", rescoreQuery.k(), equalTo(expectedResults));
+            assertThat("Unexpected candidates", knnQuery.getK(), equalTo(expectedCandidates));
+            assertThat("Unexpected k parameter", knnQuery.kParam(), equalTo(expectedK));
+        }
     }
 }

+ 2 - 1
server/src/test/java/org/elasticsearch/search/vectors/DiversifyingParentBlockQueryTests.java

@@ -120,7 +120,8 @@ public class DiversifyingParentBlockQueryTests extends MapperServiceTestCase {
                     null,
                     null,
                     bitSetproducer,
-                    DenseVectorFieldMapper.FilterHeuristic.ACORN
+                    DenseVectorFieldMapper.FilterHeuristic.ACORN,
+                    randomBoolean()
                 );
                 assertThat(knnQuery, instanceOf(DiversifyingParentBlockQuery.class));
                 var nestedQuery = new ToParentBlockJoinQuery(knnQuery, bitSetproducer, ScoreMode.Total);

+ 2 - 1
x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/TransportResumeFollowAction.java

@@ -531,7 +531,8 @@ public class TransportResumeFollowAction extends AcknowledgedTransportMasterNode
         DataTier.TIER_PREFERENCE_SETTING,
         IndexSettings.BLOOM_FILTER_ID_FIELD_ENABLED_SETTING,
         MetadataIndexStateService.VERIFIED_READ_ONLY_SETTING,
-        DenseVectorFieldMapper.HNSW_FILTER_HEURISTIC
+        DenseVectorFieldMapper.HNSW_FILTER_HEURISTIC,
+        DenseVectorFieldMapper.HNSW_EARLY_TERMINATION
     );
 
     public static Settings filter(Settings originalSettings) {