Browse Source

Only collect bulk scored vectors when exceeding min competitive (#132293)

We should not bother collecting vectors that are not competitive. This
PR adjusts the scoring interfaces to include the `max` score returned
from the block of scored vectors. Then, we will attempt to collect that
block if the max score of that block is competitive. 

This gives a nice speed improvement when querying many probes.
Benjamin Trent 2 months ago
parent
commit
47395ffee0

+ 6 - 1
libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91OSQVectorsScorer.java

@@ -141,7 +141,7 @@ public class ES91OSQVectorsScorer {
      *
      * <p>The results are stored in the provided scores array.
      */
-    public void scoreBulk(
+    public float scoreBulk(
         byte[] q,
         float queryLowerInterval,
         float queryUpperInterval,
@@ -158,6 +158,7 @@ public class ES91OSQVectorsScorer {
             targetComponentSums[i] = Short.toUnsignedInt(in.readShort());
         }
         in.readFloats(additionalCorrections, 0, BULK_SIZE);
+        float maxScore = Float.NEGATIVE_INFINITY;
         for (int i = 0; i < BULK_SIZE; i++) {
             scores[i] = score(
                 queryLowerInterval,
@@ -172,6 +173,10 @@ public class ES91OSQVectorsScorer {
                 additionalCorrections[i],
                 scores[i]
             );
+            if (scores[i] > maxScore) {
+                maxScore = scores[i];
+            }
         }
+        return maxScore;
     }
 }

+ 16 - 8
libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java

@@ -352,7 +352,7 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
     }
 
     @Override
-    public void scoreBulk(
+    public float scoreBulk(
         byte[] q,
         float queryLowerInterval,
         float queryUpperInterval,
@@ -366,7 +366,7 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
         // 128 / 8 == 16
         if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) {
             if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) {
-                score256Bulk(
+                return score256Bulk(
                     q,
                     queryLowerInterval,
                     queryUpperInterval,
@@ -376,9 +376,8 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
                     centroidDp,
                     scores
                 );
-                return;
             } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) {
-                score128Bulk(
+                return score128Bulk(
                     q,
                     queryLowerInterval,
                     queryUpperInterval,
@@ -388,10 +387,9 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
                     centroidDp,
                     scores
                 );
-                return;
             }
         }
-        super.scoreBulk(
+        return super.scoreBulk(
             q,
             queryLowerInterval,
             queryUpperInterval,
@@ -403,7 +401,7 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
         );
     }
 
-    private void score128Bulk(
+    private float score128Bulk(
         byte[] q,
         float queryLowerInterval,
         float queryUpperInterval,
@@ -420,6 +418,7 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
         float ay = queryLowerInterval;
         float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE;
         float y1 = queryComponentSum;
+        float maxScore = Float.NEGATIVE_INFINITY;
         for (; i < limit; i += FLOAT_SPECIES_128.length()) {
             var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_128, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
             var lx = FloatVector.fromMemorySegment(
@@ -453,6 +452,7 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
             if (similarityFunction == EUCLIDEAN) {
                 res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f);
                 res = FloatVector.broadcast(FLOAT_SPECIES_128, 1).div(res).max(0);
+                maxScore = res.reduceLanes(VectorOperators.MAX);
                 res.intoArray(scores, i);
             } else {
                 // For cosine and max inner product, we need to apply the additional correction, which is
@@ -463,17 +463,20 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
                     // not sure how to do it better
                     for (int j = 0; j < FLOAT_SPECIES_128.length(); j++) {
                         scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]);
+                        maxScore = Math.max(maxScore, scores[i + j]);
                     }
                 } else {
                     res = res.add(1f).mul(0.5f).max(0);
                     res.intoArray(scores, i);
+                    maxScore = res.reduceLanes(VectorOperators.MAX);
                 }
             }
         }
         in.seek(offset + 14L * BULK_SIZE);
+        return maxScore;
     }
 
-    private void score256Bulk(
+    private float score256Bulk(
         byte[] q,
         float queryLowerInterval,
         float queryUpperInterval,
@@ -490,6 +493,7 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
         float ay = queryLowerInterval;
         float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE;
         float y1 = queryComponentSum;
+        float maxScore = Float.NEGATIVE_INFINITY;
         for (; i < limit; i += FLOAT_SPECIES_256.length()) {
             var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_256, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
             var lx = FloatVector.fromMemorySegment(
@@ -523,6 +527,7 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
             if (similarityFunction == EUCLIDEAN) {
                 res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f);
                 res = FloatVector.broadcast(FLOAT_SPECIES_256, 1).div(res).max(0);
+                maxScore = res.reduceLanes(VectorOperators.MAX);
                 res.intoArray(scores, i);
             } else {
                 // For cosine and max inner product, we need to apply the additional correction, which is
@@ -533,13 +538,16 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
                     // not sure how to do it better
                     for (int j = 0; j < FLOAT_SPECIES_256.length(); j++) {
                         scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]);
+                        maxScore = Math.max(maxScore, scores[i + j]);
                     }
                 } else {
                     res = res.add(1f).mul(0.5f).max(0);
+                    maxScore = res.reduceLanes(VectorOperators.MAX);
                     res.intoArray(scores, i);
                 }
             }
         }
         in.seek(offset + 14L * BULK_SIZE);
+        return maxScore;
     }
 }

+ 35 - 17
server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java

@@ -372,7 +372,8 @@ public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeap
             return vectors;
         }
 
-        void scoreIndividually(int offset) throws IOException {
+        float scoreIndividually(int offset) throws IOException {
+            float maxScore = Float.NEGATIVE_INFINITY;
             // score individually, first the quantized byte chunk
             for (int j = 0; j < BULK_SIZE; j++) {
                 int doc = docIdsScratch[j + offset];
@@ -407,8 +408,35 @@ public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeap
                         correctionsAdd[j],
                         scores[j]
                     );
+                    if (scores[j] > maxScore) {
+                        maxScore = scores[j];
+                    }
+                }
+            }
+            return maxScore;
+        }
+
+        private static int filterDocs(int[] docIds, int offset, IntPredicate needsScoring) {
+            int filtered = 0;
+            for (int i = 0; i < ES91OSQVectorsScorer.BULK_SIZE; i++) {
+                if (needsScoring.test(docIds[offset + i]) == false) {
+                    docIds[offset + i] = -1;
+                    filtered++;
+                }
+            }
+            return filtered;
+        }
+
+        private static int collect(int[] docIds, int offset, KnnCollector knnCollector, float[] scores) {
+            int scoredDocs = 0;
+            for (int i = 0; i < ES91OSQVectorsScorer.BULK_SIZE; i++) {
+                int doc = docIds[offset + i];
+                if (doc != -1) {
+                    scoredDocs++;
+                    knnCollector.collect(doc, scores[i]);
                 }
             }
+            return scoredDocs;
         }
 
         @Override
@@ -418,23 +446,17 @@ public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeap
             int limit = vectors - BULK_SIZE + 1;
             int i = 0;
             for (; i < limit; i += BULK_SIZE) {
-                int docsToScore = BULK_SIZE;
-                for (int j = 0; j < BULK_SIZE; j++) {
-                    int doc = docIdsScratch[i + j];
-                    if (needsScoring.test(doc) == false) {
-                        docIdsScratch[i + j] = -1;
-                        docsToScore--;
-                    }
-                }
+                int docsToScore = BULK_SIZE - filterDocs(docIdsScratch, i, needsScoring);
                 if (docsToScore == 0) {
                     continue;
                 }
                 quantizeQueryIfNecessary();
                 indexInput.seek(slicePos + i * quantizedByteLength);
+                float maxScore = Float.NEGATIVE_INFINITY;
                 if (docsToScore < BULK_SIZE / 2) {
-                    scoreIndividually(i);
+                    maxScore = scoreIndividually(i);
                 } else {
-                    osqVectorsScorer.scoreBulk(
+                    maxScore = osqVectorsScorer.scoreBulk(
                         quantizedQueryScratch,
                         queryCorrections.lowerInterval(),
                         queryCorrections.upperInterval(),
@@ -445,12 +467,8 @@ public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeap
                         scores
                     );
                 }
-                for (int j = 0; j < BULK_SIZE; j++) {
-                    int doc = docIdsScratch[i + j];
-                    if (doc != -1) {
-                        scoredDocs++;
-                        knnCollector.collect(doc, scores[j]);
-                    }
+                if (knnCollector.minCompetitiveSimilarity() < maxScore) {
+                    scoredDocs += collect(docIdsScratch, i, knnCollector, scores);
                 }
             }
             // process tail