|
@@ -9,20 +9,32 @@
|
|
|
|
|
|
package org.elasticsearch.search.vectors;
|
|
|
|
|
|
+import org.apache.lucene.index.FloatVectorValues;
|
|
|
+import org.apache.lucene.index.KnnVectorValues;
|
|
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
|
|
import org.apache.lucene.queries.function.FunctionScoreQuery;
|
|
|
import org.apache.lucene.search.BooleanClause;
|
|
|
+import org.apache.lucene.search.ConjunctionUtils;
|
|
|
+import org.apache.lucene.search.DocAndFloatFeatureBuffer;
|
|
|
+import org.apache.lucene.search.DocIdSetIterator;
|
|
|
import org.apache.lucene.search.IndexSearcher;
|
|
|
import org.apache.lucene.search.KnnByteVectorQuery;
|
|
|
import org.apache.lucene.search.KnnFloatVectorQuery;
|
|
|
+import org.apache.lucene.search.MatchAllDocsQuery;
|
|
|
+import org.apache.lucene.search.MatchNoDocsQuery;
|
|
|
import org.apache.lucene.search.Query;
|
|
|
import org.apache.lucene.search.QueryVisitor;
|
|
|
+import org.apache.lucene.search.ScoreDoc;
|
|
|
+import org.apache.lucene.search.ScoreMode;
|
|
|
import org.apache.lucene.search.TopDocs;
|
|
|
-import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource;
|
|
|
+import org.elasticsearch.index.codec.vectors.BulkScorableFloatVectorValues;
|
|
|
+import org.elasticsearch.index.codec.vectors.BulkScorableVectorValues;
|
|
|
import org.elasticsearch.search.profile.query.QueryProfiler;
|
|
|
|
|
|
import java.io.IOException;
|
|
|
+import java.util.ArrayList;
|
|
|
import java.util.Arrays;
|
|
|
+import java.util.List;
|
|
|
import java.util.Objects;
|
|
|
|
|
|
/**
|
|
@@ -159,10 +171,8 @@ public abstract class RescoreKnnVectorQuery extends Query implements QueryProfil
|
|
|
|
|
|
// Patch hunk for rewrite(IndexSearcher): lines prefixed with '-' are the removed
// FunctionScoreQuery-based implementation, lines prefixed with '+' are its replacement.
//
// After the change, the method rescores the matches of innerQuery against floatTarget by
// searching a DirectRescoreKnnVectorQuery and keeping the top k hits. The total hit count of
// that search is stored in vectorOperations, and the resulting score docs are returned wrapped
// in a KnnScoreDocQuery bound to the searcher's index reader.
@Override
|
|
|
public Query rewrite(IndexSearcher searcher) throws IOException {
|
|
|
- var valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction);
|
|
|
- var functionScoreQuery = new FunctionScoreQuery(innerQuery, valueSource);
|
|
|
- // Retrieve top k documents from the function score query
|
|
|
- var topDocs = searcher.search(functionScoreQuery, k);
|
|
|
// New path: the rescoring query is built from the field name, the float target and the inner query.
+ var rescoreQuery = new DirectRescoreKnnVectorQuery(fieldName, floatTarget, innerQuery);
|
|
|
// Only the k best rescored documents are retained.
+ var topDocs = searcher.search(rescoreQuery, k);
|
|
|
// Record the total hit count reported by the rescoring search.
vectorOperations = topDocs.totalHits.value();
|
|
|
return new KnnScoreDocQuery(topDocs.scoreDocs, searcher.getIndexReader());
|
|
|
}
|
|
@@ -204,8 +214,7 @@ public abstract class RescoreKnnVectorQuery extends Query implements QueryProfil
|
|
|
|
|
|
// Retrieve top `k` documents from the top `rescoreK` query
|
|
|
var topDocsQuery = new KnnScoreDocQuery(topDocs.scoreDocs, searcher.getIndexReader());
|
|
|
- var valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction);
|
|
|
- var rescoreQuery = new FunctionScoreQuery(topDocsQuery, valueSource);
|
|
|
+ var rescoreQuery = new DirectRescoreKnnVectorQuery(fieldName, floatTarget, topDocsQuery);
|
|
|
var rescoreTopDocs = searcher.search(rescoreQuery.rewrite(searcher), k);
|
|
|
return new KnnScoreDocQuery(rescoreTopDocs.scoreDocs, searcher.getIndexReader());
|
|
|
}
|
|
@@ -223,4 +232,125 @@ public abstract class RescoreKnnVectorQuery extends Query implements QueryProfil
|
|
|
return Objects.hash(super.hashCode(), rescoreK);
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ private static class DirectRescoreKnnVectorQuery extends Query {
|
|
|
+ private final float[] floatTarget;
|
|
|
+ private final String fieldName;
|
|
|
+ private final Query innerQuery;
|
|
|
+
|
|
|
+ DirectRescoreKnnVectorQuery(String fieldName, float[] floatTarget, Query innerQuery) {
|
|
|
+ this.fieldName = fieldName;
|
|
|
+ this.floatTarget = floatTarget;
|
|
|
+ this.innerQuery = innerQuery;
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public String toString(String field) {
|
|
|
+ return "DirectRescoreKnnVectorQuery[" + innerQuery + "]";
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
|
|
|
+ Query innerRewritten = innerQuery.rewrite(indexSearcher);
|
|
|
+ if (innerRewritten.getClass() == MatchNoDocsQuery.class) {
|
|
|
+ return new MatchNoDocsQuery();
|
|
|
+ }
|
|
|
+ assert innerRewritten.getClass() != MatchAllDocsQuery.class;
|
|
|
+
|
|
|
+ List<ScoreDoc> results = new ArrayList<>(10);
|
|
|
+ for (var leaf : indexSearcher.getIndexReader().leaves()) {
|
|
|
+ var knnVectorValues = leaf.reader().getFloatVectorValues(fieldName);
|
|
|
+ if (knnVectorValues == null) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ if (knnVectorValues.dimension() != floatTarget.length) {
|
|
|
+ throw new IllegalArgumentException(
|
|
|
+ "vector query dimension: " + floatTarget.length + " differs from field dimension: " + knnVectorValues.dimension()
|
|
|
+ );
|
|
|
+ }
|
|
|
+ var weight = innerRewritten.createWeight(indexSearcher, ScoreMode.COMPLETE_NO_SCORES, 1.0f);
|
|
|
+ var scorer = weight.scorer(leaf);
|
|
|
+ if (scorer == null) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ var filterIterator = scorer.iterator();
|
|
|
+ if (knnVectorValues instanceof BulkScorableFloatVectorValues rescorableVectorValues) {
|
|
|
+ rescoreBulk(leaf.docBase, rescorableVectorValues, results, filterIterator);
|
|
|
+ } else {
|
|
|
+ rescoreIndividually(
|
|
|
+ leaf.docBase,
|
|
|
+ knnVectorValues,
|
|
|
+ leaf.reader().getFieldInfos().fieldInfo(fieldName).getVectorSimilarityFunction(),
|
|
|
+ results,
|
|
|
+ filterIterator
|
|
|
+ );
|
|
|
+ }
|
|
|
+ }
|
|
|
+ // Remove any remaining sentinel values
|
|
|
+ ScoreDoc[] arrayResults = results.toArray(new ScoreDoc[0]);
|
|
|
+ return new KnnScoreDocQuery(arrayResults, indexSearcher.getIndexReader());
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void visit(QueryVisitor visitor) {
|
|
|
+ if (visitor.acceptField(fieldName)) {
|
|
|
+ visitor.visitLeaf(this);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public boolean equals(Object obj) {
|
|
|
+ if (this == obj) return true;
|
|
|
+ if (obj == null || getClass() != obj.getClass()) return false;
|
|
|
+ DirectRescoreKnnVectorQuery that = (DirectRescoreKnnVectorQuery) obj;
|
|
|
+ return Objects.equals(innerQuery, that.innerQuery);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public int hashCode() {
|
|
|
+ return Objects.hash(innerQuery, getClass());
|
|
|
+ }
|
|
|
+
|
|
|
+ private void rescoreBulk(
|
|
|
+ int docBase,
|
|
|
+ BulkScorableFloatVectorValues rescorableVectorValues,
|
|
|
+ List<ScoreDoc> queue,
|
|
|
+ DocIdSetIterator filterIterator
|
|
|
+ ) throws IOException {
|
|
|
+ BulkScorableVectorValues.BulkVectorScorer vectorReScorer = rescorableVectorValues.rescorer(floatTarget);
|
|
|
+ var iterator = vectorReScorer.iterator();
|
|
|
+ BulkScorableVectorValues.BulkVectorScorer.Bulk bulkScorer = vectorReScorer.bulk(filterIterator);
|
|
|
+ DocAndFloatFeatureBuffer buffer = new DocAndFloatFeatureBuffer();
|
|
|
+ while (iterator.docID() != DocIdSetIterator.NO_MORE_DOCS) {
|
|
|
+ // iterator already takes live docs into account
|
|
|
+ bulkScorer.nextDocsAndScores(64, null, buffer);
|
|
|
+ for (int i = 0; i < buffer.size; i++) {
|
|
|
+ float score = buffer.features[i];
|
|
|
+ int doc = buffer.docs[i];
|
|
|
+ queue.add(new ScoreDoc(doc + docBase, score));
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ private void rescoreIndividually(
|
|
|
+ int docBase,
|
|
|
+ FloatVectorValues knnVectorValues,
|
|
|
+ VectorSimilarityFunction function,
|
|
|
+ List<ScoreDoc> queue,
|
|
|
+ DocIdSetIterator filterIterator
|
|
|
+ ) throws IOException {
|
|
|
+ int doc;
|
|
|
+ KnnVectorValues.DocIndexIterator knnVectorIterator = knnVectorValues.iterator();
|
|
|
+ var conjunction = ConjunctionUtils.intersectIterators(List.of(knnVectorIterator, filterIterator));
|
|
|
+ while ((doc = conjunction.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
|
|
|
+ assert doc == knnVectorIterator.docID();
|
|
|
+ float[] vector = knnVectorValues.vectorValue(knnVectorIterator.index());
|
|
|
+ float score = function.compare(floatTarget, vector);
|
|
|
+ if (Float.isNaN(score)) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ queue.add(new ScoreDoc(doc + docBase, score));
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|