Browse Source

New bulk scorer interface for vectors (#135292)

* Removing duplicate logger declaration

* Add new bulk scoring interface and use within vector rescoring

* iter

* fixing empty leaf handling

* removing unused code
Benjamin Trent 2 weeks ago
parent
commit
f0425ba457

+ 29 - 0
server/src/main/java/org/elasticsearch/index/codec/vectors/BulkScorableByteVectorValues.java

@@ -0,0 +1,29 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.index.codec.vectors;
+
+import java.io.IOException;
+
+/**
+ * Extension to {@link BulkScorableVectorValues} for byte[] vectors
+ */
+public interface BulkScorableByteVectorValues extends BulkScorableVectorValues {
+    /**
+     * Returns a {@link BulkVectorScorer} that can score against the provided {@code target} vector.
+     * It will score to the fastest speed possible, potentially sacrificing some fidelity.
+     */
+    BulkVectorScorer scorer(byte[] target) throws IOException;
+
+    /**
+     * Returns a {@link BulkVectorScorer} that can rescore against the provided {@code target} vector.
+     * It will score to the highest fidelity possible, potentially sacrificing some speed.
+     */
+    BulkVectorScorer rescorer(byte[] target) throws IOException;
+}

+ 29 - 0
server/src/main/java/org/elasticsearch/index/codec/vectors/BulkScorableFloatVectorValues.java

@@ -0,0 +1,29 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.index.codec.vectors;
+
+import java.io.IOException;
+
+/**
+ * Extension to {@link BulkScorableVectorValues} for byte[] vectors
+ */
+public interface BulkScorableFloatVectorValues extends BulkScorableVectorValues {
+    /**
+     * Returns a {@link BulkVectorScorer} that can score against the provided {@code target} vector.
+     * It will score to the fastest speed possible, potentially sacrificing some fidelity.
+     */
+    BulkVectorScorer scorer(float[] target) throws IOException;
+
+    /**
+     * Returns a {@link BulkVectorScorer} that can rescore against the provided {@code target} vector.
+     * It will score to the highest fidelity possible, potentially sacrificing some speed.
+     */
+    BulkVectorScorer rescorer(float[] target) throws IOException;
+}

+ 38 - 0
server/src/main/java/org/elasticsearch/index/codec/vectors/BulkScorableVectorValues.java

@@ -0,0 +1,38 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.index.codec.vectors;
+
+import org.apache.lucene.search.DocAndFloatFeatureBuffer;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.VectorScorer;
+import org.apache.lucene.util.Bits;
+
+import java.io.IOException;
+
+/**
+ * Extension to {@link org.apache.lucene.search.VectorScorer} that can score in bulk
+ */
+public interface BulkScorableVectorValues {
+    interface BulkVectorScorer extends VectorScorer {
+
+        /**
+         * Returns a {@link Bulk} scorer that can score in bulk the provided {@code matchingDocs}.
+         */
+        Bulk bulk(DocIdSetIterator matchingDocs) throws IOException;
+
+        interface Bulk {
+            /**
+             * Scores up to {@code nextCount} docs in the provided {@code buffer}.
+             * Returns the maxScore of docs scored.
+             */
+            float nextDocsAndScores(int nextCount, Bits liveDocs, DocAndFloatFeatureBuffer buffer) throws IOException;
+        }
+    }
+}

+ 0 - 96
server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorSimilarityFloatValueSource.java

@@ -1,96 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the "Elastic License
- * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
- * Public License v 1"; you may not use this file except in compliance with, at
- * your election, the "Elastic License 2.0", the "GNU Affero General Public
- * License v3.0 only", or the "Server Side Public License, v 1".
- */
-
-package org.elasticsearch.index.mapper.vectors;
-
-import org.apache.lucene.index.FloatVectorValues;
-import org.apache.lucene.index.KnnVectorValues;
-import org.apache.lucene.index.LeafReader;
-import org.apache.lucene.index.LeafReaderContext;
-import org.apache.lucene.index.VectorSimilarityFunction;
-import org.apache.lucene.search.DocIdSetIterator;
-import org.apache.lucene.search.DoubleValues;
-import org.apache.lucene.search.DoubleValuesSource;
-import org.apache.lucene.search.IndexSearcher;
-
-import java.io.IOException;
-import java.util.Arrays;
-import java.util.Objects;
-
-/**
- * DoubleValuesSource that is used to calculate scores according to a similarity function for a KnnFloatVectorField, using the
- * original vector values stored in the index
- */
-public class VectorSimilarityFloatValueSource extends DoubleValuesSource {
-
-    private final String field;
-    private final float[] target;
-    private final VectorSimilarityFunction vectorSimilarityFunction;
-
-    public VectorSimilarityFloatValueSource(String field, float[] target, VectorSimilarityFunction vectorSimilarityFunction) {
-        this.field = field;
-        this.target = target;
-        this.vectorSimilarityFunction = vectorSimilarityFunction;
-    }
-
-    @Override
-    public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException {
-        final LeafReader reader = ctx.reader();
-
-        FloatVectorValues vectorValues = reader.getFloatVectorValues(field);
-        final KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
-
-        return new DoubleValues() {
-            @Override
-            public double doubleValue() throws IOException {
-                return vectorSimilarityFunction.compare(target, vectorValues.vectorValue(iterator.index()));
-            }
-
-            @Override
-            public boolean advanceExact(int doc) throws IOException {
-                return doc >= iterator.docID() && iterator.docID() != DocIdSetIterator.NO_MORE_DOCS && iterator.advance(doc) == doc;
-            }
-        };
-    }
-
-    @Override
-    public boolean needsScores() {
-        return false;
-    }
-
-    @Override
-    public DoubleValuesSource rewrite(IndexSearcher reader) throws IOException {
-        return this;
-    }
-
-    @Override
-    public int hashCode() {
-        return Objects.hash(field, Arrays.hashCode(target), vectorSimilarityFunction);
-    }
-
-    @Override
-    public boolean equals(Object o) {
-        if (this == o) return true;
-        if (o == null || getClass() != o.getClass()) return false;
-        VectorSimilarityFloatValueSource that = (VectorSimilarityFloatValueSource) o;
-        return Objects.equals(field, that.field)
-            && Arrays.equals(target, that.target)
-            && vectorSimilarityFunction == that.vectorSimilarityFunction;
-    }
-
-    @Override
-    public String toString() {
-        return "VectorSimilarityFloatValueSource(" + field + ", [" + target[0] + ",...], " + vectorSimilarityFunction + ")";
-    }
-
-    @Override
-    public boolean isCacheable(LeafReaderContext ctx) {
-        return false;
-    }
-}

+ 137 - 7
server/src/main/java/org/elasticsearch/search/vectors/RescoreKnnVectorQuery.java

@@ -9,20 +9,32 @@
 
 package org.elasticsearch.search.vectors;
 
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.KnnVectorValues;
 import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.queries.function.FunctionScoreQuery;
 import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.ConjunctionUtils;
+import org.apache.lucene.search.DocAndFloatFeatureBuffer;
+import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.KnnByteVectorQuery;
 import org.apache.lucene.search.KnnFloatVectorQuery;
+import org.apache.lucene.search.MatchAllDocsQuery;
+import org.apache.lucene.search.MatchNoDocsQuery;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.QueryVisitor;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.search.TopDocs;
-import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource;
+import org.elasticsearch.index.codec.vectors.BulkScorableFloatVectorValues;
+import org.elasticsearch.index.codec.vectors.BulkScorableVectorValues;
 import org.elasticsearch.search.profile.query.QueryProfiler;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.List;
 import java.util.Objects;
 
 /**
@@ -159,10 +171,8 @@ public abstract class RescoreKnnVectorQuery extends Query implements QueryProfil
 
         @Override
         public Query rewrite(IndexSearcher searcher) throws IOException {
-            var valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction);
-            var functionScoreQuery = new FunctionScoreQuery(innerQuery, valueSource);
-            // Retrieve top k documents from the function score query
-            var topDocs = searcher.search(functionScoreQuery, k);
+            var rescoreQuery = new DirectRescoreKnnVectorQuery(fieldName, floatTarget, innerQuery);
+            var topDocs = searcher.search(rescoreQuery, k);
             vectorOperations = topDocs.totalHits.value();
             return new KnnScoreDocQuery(topDocs.scoreDocs, searcher.getIndexReader());
         }
@@ -204,8 +214,7 @@ public abstract class RescoreKnnVectorQuery extends Query implements QueryProfil
 
             // Retrieve top `k` documents from the top `rescoreK` query
             var topDocsQuery = new KnnScoreDocQuery(topDocs.scoreDocs, searcher.getIndexReader());
-            var valueSource = new VectorSimilarityFloatValueSource(fieldName, floatTarget, vectorSimilarityFunction);
-            var rescoreQuery = new FunctionScoreQuery(topDocsQuery, valueSource);
+            var rescoreQuery = new DirectRescoreKnnVectorQuery(fieldName, floatTarget, topDocsQuery);
             var rescoreTopDocs = searcher.search(rescoreQuery.rewrite(searcher), k);
             return new KnnScoreDocQuery(rescoreTopDocs.scoreDocs, searcher.getIndexReader());
         }
@@ -223,4 +232,125 @@ public abstract class RescoreKnnVectorQuery extends Query implements QueryProfil
             return Objects.hash(super.hashCode(), rescoreK);
         }
     }
+
+    private static class DirectRescoreKnnVectorQuery extends Query {
+        private final float[] floatTarget;
+        private final String fieldName;
+        private final Query innerQuery;
+
+        DirectRescoreKnnVectorQuery(String fieldName, float[] floatTarget, Query innerQuery) {
+            this.fieldName = fieldName;
+            this.floatTarget = floatTarget;
+            this.innerQuery = innerQuery;
+        }
+
+        @Override
+        public String toString(String field) {
+            return "DirectRescoreKnnVectorQuery[" + innerQuery + "]";
+        }
+
+        @Override
+        public Query rewrite(IndexSearcher indexSearcher) throws IOException {
+            Query innerRewritten = innerQuery.rewrite(indexSearcher);
+            if (innerRewritten.getClass() == MatchNoDocsQuery.class) {
+                return new MatchNoDocsQuery();
+            }
+            assert innerRewritten.getClass() != MatchAllDocsQuery.class;
+
+            List<ScoreDoc> results = new ArrayList<>(10);
+            for (var leaf : indexSearcher.getIndexReader().leaves()) {
+                var knnVectorValues = leaf.reader().getFloatVectorValues(fieldName);
+                if (knnVectorValues == null) {
+                    continue;
+                }
+                if (knnVectorValues.dimension() != floatTarget.length) {
+                    throw new IllegalArgumentException(
+                        "vector query dimension: " + floatTarget.length + " differs from field dimension: " + knnVectorValues.dimension()
+                    );
+                }
+                var weight = innerRewritten.createWeight(indexSearcher, ScoreMode.COMPLETE_NO_SCORES, 1.0f);
+                var scorer = weight.scorer(leaf);
+                if (scorer == null) {
+                    continue;
+                }
+                var filterIterator = scorer.iterator();
+                if (knnVectorValues instanceof BulkScorableFloatVectorValues rescorableVectorValues) {
+                    rescoreBulk(leaf.docBase, rescorableVectorValues, results, filterIterator);
+                } else {
+                    rescoreIndividually(
+                        leaf.docBase,
+                        knnVectorValues,
+                        leaf.reader().getFieldInfos().fieldInfo(fieldName).getVectorSimilarityFunction(),
+                        results,
+                        filterIterator
+                    );
+                }
+            }
+            // Remove any remaining sentinel values
+            ScoreDoc[] arrayResults = results.toArray(new ScoreDoc[0]);
+            return new KnnScoreDocQuery(arrayResults, indexSearcher.getIndexReader());
+        }
+
+        @Override
+        public void visit(QueryVisitor visitor) {
+            if (visitor.acceptField(fieldName)) {
+                visitor.visitLeaf(this);
+            }
+        }
+
+        @Override
+        public boolean equals(Object obj) {
+            if (this == obj) return true;
+            if (obj == null || getClass() != obj.getClass()) return false;
+            DirectRescoreKnnVectorQuery that = (DirectRescoreKnnVectorQuery) obj;
+            return Objects.equals(innerQuery, that.innerQuery);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(innerQuery, getClass());
+        }
+
+        private void rescoreBulk(
+            int docBase,
+            BulkScorableFloatVectorValues rescorableVectorValues,
+            List<ScoreDoc> queue,
+            DocIdSetIterator filterIterator
+        ) throws IOException {
+            BulkScorableVectorValues.BulkVectorScorer vectorReScorer = rescorableVectorValues.rescorer(floatTarget);
+            var iterator = vectorReScorer.iterator();
+            BulkScorableVectorValues.BulkVectorScorer.Bulk bulkScorer = vectorReScorer.bulk(filterIterator);
+            DocAndFloatFeatureBuffer buffer = new DocAndFloatFeatureBuffer();
+            while (iterator.docID() != DocIdSetIterator.NO_MORE_DOCS) {
+                // iterator already takes live docs into account
+                bulkScorer.nextDocsAndScores(64, null, buffer);
+                for (int i = 0; i < buffer.size; i++) {
+                    float score = buffer.features[i];
+                    int doc = buffer.docs[i];
+                    queue.add(new ScoreDoc(doc + docBase, score));
+                }
+            }
+        }
+
+        private void rescoreIndividually(
+            int docBase,
+            FloatVectorValues knnVectorValues,
+            VectorSimilarityFunction function,
+            List<ScoreDoc> queue,
+            DocIdSetIterator filterIterator
+        ) throws IOException {
+            int doc;
+            KnnVectorValues.DocIndexIterator knnVectorIterator = knnVectorValues.iterator();
+            var conjunction = ConjunctionUtils.intersectIterators(List.of(knnVectorIterator, filterIterator));
+            while ((doc = conjunction.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
+                assert doc == knnVectorIterator.docID();
+                float[] vector = knnVectorValues.vectorValue(knnVectorIterator.index());
+                float score = function.compare(floatTarget, vector);
+                if (Float.isNaN(score)) {
+                    continue;
+                }
+                queue.add(new ScoreDoc(doc + docBase, score));
+            }
+        }
+    }
 }

+ 3 - 3
server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java

@@ -22,6 +22,7 @@ import org.apache.lucene.search.BooleanClause;
 import org.apache.lucene.search.BooleanQuery;
 import org.apache.lucene.search.DoubleValuesSource;
 import org.apache.lucene.search.FieldExistsQuery;
+import org.apache.lucene.search.FullPrecisionFloatVectorSimilarityValuesSource;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.KnnFloatVectorQuery;
 import org.apache.lucene.search.MatchAllDocsQuery;
@@ -38,7 +39,6 @@ import org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsForm
 import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat;
 import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat;
 import org.elasticsearch.index.codec.zstd.Zstd814StoredFieldsFormat;
-import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource;
 import org.elasticsearch.search.profile.query.QueryProfiler;
 import org.elasticsearch.test.ESTestCase;
 
@@ -89,9 +89,9 @@ public class RescoreKnnVectorQueryTests extends ESTestCase {
                     assertThat(rescoredDocs.scoreDocs.length, equalTo(k));
 
                     // Get real scores
-                    DoubleValuesSource valueSource = new VectorSimilarityFloatValueSource(
-                        FIELD_NAME,
+                    DoubleValuesSource valueSource = new FullPrecisionFloatVectorSimilarityValuesSource(
                         queryVector,
+                        FIELD_NAME,
                         VectorSimilarityFunction.COSINE
                     );
                     FunctionScoreQuery functionScoreQuery = new FunctionScoreQuery(new MatchAllDocsQuery(), valueSource);