|
@@ -12,14 +12,18 @@ import org.apache.lucene.codecs.StoredFieldsReader;
|
|
|
import org.apache.lucene.index.DirectoryReader;
|
|
|
import org.apache.lucene.index.FilterDirectoryReader;
|
|
|
import org.apache.lucene.index.FilterLeafReader;
|
|
|
+import org.apache.lucene.index.FilterVectorValues;
|
|
|
import org.apache.lucene.index.IndexReader;
|
|
|
import org.apache.lucene.index.LeafReader;
|
|
|
import org.apache.lucene.index.PointValues;
|
|
|
import org.apache.lucene.index.QueryTimeout;
|
|
|
import org.apache.lucene.index.Terms;
|
|
|
import org.apache.lucene.index.TermsEnum;
|
|
|
+import org.apache.lucene.index.VectorValues;
|
|
|
import org.apache.lucene.search.DocIdSetIterator;
|
|
|
+import org.apache.lucene.search.TopDocs;
|
|
|
import org.apache.lucene.search.suggest.document.CompletionTerms;
|
|
|
+import org.apache.lucene.util.Bits;
|
|
|
import org.apache.lucene.util.BytesRef;
|
|
|
import org.apache.lucene.util.automaton.CompiledAutomaton;
|
|
|
import org.elasticsearch.common.lucene.index.SequentialStoredFieldsLeafReader;
|
|
@@ -29,6 +33,8 @@ import java.io.IOException;
|
|
|
/**
|
|
|
* Wraps an {@link IndexReader} with a {@link QueryCancellation}
|
|
|
* which checks for cancelled or timed-out query.
|
|
|
+ * Note: this class was adapted from Lucene's ExitableDirectoryReader, but instead of using a query timeout for cancellation,
|
|
|
+ * a {@link QueryCancellation} object is used. The main behavior of the classes is mostly unchanged.
|
|
|
*/
|
|
|
class ExitableDirectoryReader extends FilterDirectoryReader {
|
|
|
|
|
@@ -119,6 +125,45 @@ class ExitableDirectoryReader extends FilterDirectoryReader {
|
|
|
protected StoredFieldsReader doGetSequentialStoredFieldsReader(StoredFieldsReader reader) {
|
|
|
return reader;
|
|
|
}
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public VectorValues getVectorValues(String field) throws IOException {
|
|
|
+ VectorValues vectorValues = in.getVectorValues(field);
|
|
|
+ if (vectorValues == null) {
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+ return queryCancellation.isEnabled() ? new ExitableVectorValues(vectorValues, queryCancellation) : vectorValues;
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public TopDocs searchNearestVectors(String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
|
|
|
+ if (queryCancellation.isEnabled() == false) {
|
|
|
+ return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
|
|
|
+ }
|
|
|
+ // when acceptDocs is null due to no doc deleted, we will instantiate a new one that would
|
|
|
+ // match all docs to allow timeout checking.
|
|
|
+ final Bits updatedAcceptDocs = acceptDocs == null ? new Bits.MatchAllBits(maxDoc()) : acceptDocs;
|
|
|
+ Bits timeoutCheckingAcceptDocs = new Bits() {
|
|
|
+ private static final int MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK = 10;
|
|
|
+ private int calls;
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public boolean get(int index) {
|
|
|
+ if (calls++ % MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK == 0) {
|
|
|
+ queryCancellation.checkCancelled();
|
|
|
+ }
|
|
|
+
|
|
|
+ return updatedAcceptDocs.get(index);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public int length() {
|
|
|
+ return updatedAcceptDocs.length();
|
|
|
+ }
|
|
|
+ };
|
|
|
+
|
|
|
+ return in.searchNearestVectors(field, target, k, timeoutCheckingAcceptDocs, visitedLimit);
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
/**
|
|
@@ -377,4 +422,50 @@ class ExitableDirectoryReader extends FilterDirectoryReader {
|
|
|
in.grow(count);
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ private static class ExitableVectorValues extends FilterVectorValues {
|
|
|
+ private static final int DOCS_BETWEEN_TIMEOUT_CHECK = 1000;
|
|
|
+ private int docToCheck;
|
|
|
+ private final QueryCancellation queryCancellation;
|
|
|
+
|
|
|
+ ExitableVectorValues(VectorValues vectorValues, QueryCancellation queryCancellation) {
|
|
|
+ super(vectorValues);
|
|
|
+ docToCheck = 0;
|
|
|
+ this.queryCancellation = queryCancellation;
|
|
|
+ this.queryCancellation.checkCancelled();
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public int advance(int target) throws IOException {
|
|
|
+ final int advance = super.advance(target);
|
|
|
+ checkAndThrowWithSampling(advance);
|
|
|
+ return advance;
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public int nextDoc() throws IOException {
|
|
|
+ final int nextDoc = super.nextDoc();
|
|
|
+ checkAndThrowWithSampling(nextDoc);
|
|
|
+ return nextDoc;
|
|
|
+ }
|
|
|
+
|
|
|
+ private void checkAndThrowWithSampling(int nextDoc) {
|
|
|
+ if (nextDoc >= docToCheck) {
|
|
|
+ this.queryCancellation.checkCancelled();
|
|
|
+ docToCheck = nextDoc + DOCS_BETWEEN_TIMEOUT_CHECK;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public float[] vectorValue() throws IOException {
|
|
|
+ this.queryCancellation.checkCancelled();
|
|
|
+ return in.vectorValue();
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public BytesRef binaryValue() throws IOException {
|
|
|
+ this.queryCancellation.checkCancelled();
|
|
|
+ return in.binaryValue();
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|