浏览代码

Make `knn` search requests fully cancellable (#90612)

Approximate nearest-neighbor search requests should be cancelled if the overall search request is cancelled. Currently, the cancellation is not respected by searches provided in the knn section of the search request.

This commit ports over the ExitableDirectoryReader changes made in Lucene. The main differences (besides formatting) is changing the timeout checker to be our query cancellation checker.

Related to: apache/lucene#833

closes: #90323
Benjamin Trent 3 年之前
父节点
当前提交
2a17e5302d

+ 5 - 0
docs/changelog/90612.yaml

@@ -0,0 +1,5 @@
+pr: 90612
+summary: Make `knn` search requests fully cancellable
+area: Vector Search
+type: bug
+issues: []

+ 91 - 0
server/src/main/java/org/elasticsearch/search/internal/ExitableDirectoryReader.java

@@ -12,14 +12,18 @@ import org.apache.lucene.codecs.StoredFieldsReader;
 import org.apache.lucene.index.DirectoryReader;
 import org.apache.lucene.index.FilterDirectoryReader;
 import org.apache.lucene.index.FilterLeafReader;
+import org.apache.lucene.index.FilterVectorValues;
 import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.PointValues;
 import org.apache.lucene.index.QueryTimeout;
 import org.apache.lucene.index.Terms;
 import org.apache.lucene.index.TermsEnum;
+import org.apache.lucene.index.VectorValues;
 import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.search.suggest.document.CompletionTerms;
+import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.automaton.CompiledAutomaton;
 import org.elasticsearch.common.lucene.index.SequentialStoredFieldsLeafReader;
@@ -29,6 +33,8 @@ import java.io.IOException;
 /**
  * Wraps an {@link IndexReader} with a {@link QueryCancellation}
  * which checks for cancelled or timed-out query.
+ * Note: this class was adapted from Lucene's ExitableDirectoryReader, but instead of using a query timeout for cancellation,
+ *       a {@link QueryCancellation} object is used. The main behavior of the classes is mostly unchanged.
  */
 class ExitableDirectoryReader extends FilterDirectoryReader {
 
@@ -119,6 +125,45 @@ class ExitableDirectoryReader extends FilterDirectoryReader {
         protected StoredFieldsReader doGetSequentialStoredFieldsReader(StoredFieldsReader reader) {
             return reader;
         }
+
+        @Override
+        public VectorValues getVectorValues(String field) throws IOException {
+            VectorValues vectorValues = in.getVectorValues(field);
+            if (vectorValues == null) {
+                return null;
+            }
+            return queryCancellation.isEnabled() ? new ExitableVectorValues(vectorValues, queryCancellation) : vectorValues;
+        }
+
+        @Override
+        public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
+            if (queryCancellation.isEnabled() == false) {
+                return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
+            }
+            // when acceptDocs is null due to no doc deleted, we will instantiate a new one that would
+            // match all docs to allow timeout checking.
+            final Bits updatedAcceptDocs = acceptDocs == null ? new Bits.MatchAllBits(maxDoc()) : acceptDocs;
+            Bits timeoutCheckingAcceptDocs = new Bits() {
+                private static final int MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK = 10;
+                private int calls;
+
+                @Override
+                public boolean get(int index) {
+                    if (calls++ % MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK == 0) {
+                        queryCancellation.checkCancelled();
+                    }
+
+                    return updatedAcceptDocs.get(index);
+                }
+
+                @Override
+                public int length() {
+                    return updatedAcceptDocs.length();
+                }
+            };
+
+            return in.searchNearestVectors(field, target, k, timeoutCheckingAcceptDocs, visitedLimit);
+        }
     }
 
     /**
@@ -377,4 +422,50 @@ class ExitableDirectoryReader extends FilterDirectoryReader {
             in.grow(count);
         }
     }
+
+    private static class ExitableVectorValues extends FilterVectorValues {
+        private static final int DOCS_BETWEEN_TIMEOUT_CHECK = 1000;
+        private int docToCheck;
+        private final QueryCancellation queryCancellation;
+
+        ExitableVectorValues(VectorValues vectorValues, QueryCancellation queryCancellation) {
+            super(vectorValues);
+            docToCheck = 0;
+            this.queryCancellation = queryCancellation;
+            this.queryCancellation.checkCancelled();
+        }
+
+        @Override
+        public int advance(int target) throws IOException {
+            final int advance = super.advance(target);
+            checkAndThrowWithSampling(advance);
+            return advance;
+        }
+
+        @Override
+        public int nextDoc() throws IOException {
+            final int nextDoc = super.nextDoc();
+            checkAndThrowWithSampling(nextDoc);
+            return nextDoc;
+        }
+
+        private void checkAndThrowWithSampling(int nextDoc) {
+            if (nextDoc >= docToCheck) {
+                this.queryCancellation.checkCancelled();
+                docToCheck = nextDoc + DOCS_BETWEEN_TIMEOUT_CHECK;
+            }
+        }
+
+        @Override
+        public float[] vectorValue() throws IOException {
+            this.queryCancellation.checkCancelled();
+            return in.vectorValue();
+        }
+
+        @Override
+        public BytesRef binaryValue() throws IOException {
+            this.queryCancellation.checkCancelled();
+            return in.binaryValue();
+        }
+    }
 }

+ 43 - 0
server/src/test/java/org/elasticsearch/search/SearchCancellationTests.java

@@ -10,12 +10,15 @@ package org.elasticsearch.search;
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field;
 import org.apache.lucene.document.IntPoint;
+import org.apache.lucene.document.KnnVectorField;
 import org.apache.lucene.document.StringField;
 import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.NoMergePolicy;
 import org.apache.lucene.index.PointValues;
 import org.apache.lucene.index.Terms;
 import org.apache.lucene.index.TermsEnum;
+import org.apache.lucene.index.VectorValues;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.MatchAllDocsQuery;
 import org.apache.lucene.search.TotalHitCountCollector;
@@ -40,6 +43,7 @@ public class SearchCancellationTests extends ESTestCase {
 
     private static final String STRING_FIELD_NAME = "foo";
     private static final String POINT_FIELD_NAME = "point";
+    private static final String KNN_FIELD_NAME = "vector";
 
     private static Directory dir;
     private static IndexReader reader;
@@ -63,6 +67,7 @@ public class SearchCancellationTests extends ESTestCase {
             Document doc = new Document();
             doc.add(new StringField(STRING_FIELD_NAME, "a".repeat(i), Field.Store.NO));
             doc.add(new IntPoint(POINT_FIELD_NAME, i, i + 1));
+            doc.add(new KnnVectorField(KNN_FIELD_NAME, new float[] { 1.0f, 0.5f, 42.0f }));
             w.addDocument(doc);
         }
     }
@@ -176,6 +181,44 @@ public class SearchCancellationTests extends ESTestCase {
         pointValues2.intersect(new PointValuesIntersectVisitor());
     }
 
+    public void testExitableDirectoryReaderVectors() throws IOException {
+        AtomicBoolean cancelled = new AtomicBoolean(true);
+        Runnable cancellation = () -> {
+            if (cancelled.get()) {
+                throw new TaskCancelledException("cancelled");
+            }
+        };
+        ContextIndexSearcher searcher = new ContextIndexSearcher(
+            reader,
+            IndexSearcher.getDefaultSimilarity(),
+            IndexSearcher.getDefaultQueryCache(),
+            IndexSearcher.getDefaultQueryCachingPolicy(),
+            true
+        );
+        searcher.addQueryCancellation(cancellation);
+        final LeafReader leaf = searcher.getIndexReader().leaves().get(0).reader();
+        expectThrows(TaskCancelledException.class, () -> leaf.getVectorValues(KNN_FIELD_NAME));
+        expectThrows(
+            TaskCancelledException.class,
+            () -> leaf.searchNearestVectors(KNN_FIELD_NAME, new float[] { 1f, 1f, 1f }, 2, leaf.getLiveDocs(), Integer.MAX_VALUE)
+        );
+
+        cancelled.set(false); // Avoid exception during construction of the wrapper objects
+        VectorValues vectorValues = searcher.getIndexReader().leaves().get(0).reader().getVectorValues(KNN_FIELD_NAME);
+        cancelled.set(true);
+        expectThrows(TaskCancelledException.class, vectorValues::nextDoc);
+        expectThrows(TaskCancelledException.class, vectorValues::vectorValue);
+        expectThrows(TaskCancelledException.class, vectorValues::binaryValue);
+
+        cancelled.set(false); // Avoid exception during construction of the wrapper objects
+        VectorValues uncancelledVectorValues = searcher.getIndexReader().leaves().get(0).reader().getVectorValues(KNN_FIELD_NAME);
+        cancelled.set(true);
+        searcher.removeQueryCancellation(cancellation);
+        uncancelledVectorValues.nextDoc();
+        uncancelledVectorValues.vectorValue();
+        uncancelledVectorValues.binaryValue();
+    }
+
     private static class PointValuesIntersectVisitor implements PointValues.IntersectVisitor {
         @Override
         public void visit(int docID) {}