Browse Source

Fix max score calculation in MemorySegmentES91OSQVectorsScorer (#132433)

There is an error in how we compute max score in our panamized version of ES91OSQVectorsScorer after #132293. 
This commit fixes it and increases test coverage.
Ignacio Vera 2 months ago
parent
commit
63e2a3b0eb

+ 4 - 4
libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java

@@ -452,7 +452,7 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
             if (similarityFunction == EUCLIDEAN) {
                 res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f);
                 res = FloatVector.broadcast(FLOAT_SPECIES_128, 1).div(res).max(0);
-                maxScore = res.reduceLanes(VectorOperators.MAX);
+                maxScore = Math.max(maxScore, res.reduceLanes(VectorOperators.MAX));
                 res.intoArray(scores, i);
             } else {
                 // For cosine and max inner product, we need to apply the additional correction, which is
@@ -468,7 +468,7 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
                 } else {
                     res = res.add(1f).mul(0.5f).max(0);
                     res.intoArray(scores, i);
-                    maxScore = res.reduceLanes(VectorOperators.MAX);
+                    maxScore = Math.max(maxScore, res.reduceLanes(VectorOperators.MAX));
                 }
             }
         }
@@ -527,7 +527,7 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
             if (similarityFunction == EUCLIDEAN) {
                 res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f);
                 res = FloatVector.broadcast(FLOAT_SPECIES_256, 1).div(res).max(0);
-                maxScore = res.reduceLanes(VectorOperators.MAX);
+                maxScore = Math.max(maxScore, res.reduceLanes(VectorOperators.MAX));
                 res.intoArray(scores, i);
             } else {
                 // For cosine and max inner product, we need to apply the additional correction, which is
@@ -542,7 +542,7 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
                     }
                 } else {
                     res = res.add(1f).mul(0.5f).max(0);
-                    maxScore = res.reduceLanes(VectorOperators.MAX);
+                    maxScore = Math.max(maxScore, res.reduceLanes(VectorOperators.MAX));
                     res.intoArray(scores, i);
                 }
             }

+ 195 - 85
libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java

@@ -15,16 +15,20 @@ import org.apache.lucene.store.IOContext;
 import org.apache.lucene.store.IndexInput;
 import org.apache.lucene.store.IndexOutput;
 import org.apache.lucene.store.MMapDirectory;
-import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
+import org.apache.lucene.util.VectorUtil;
+import org.elasticsearch.index.codec.vectors.BQSpaceUtils;
+import org.elasticsearch.index.codec.vectors.BQVectorUtils;
+import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
+import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
 import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
 
-import static org.hamcrest.Matchers.lessThan;
+import java.io.IOException;
 
 public class ES91OSQVectorScorerTests extends BaseVectorizationTests {
 
     public void testQuantizeScore() throws Exception {
         final int dimensions = random().nextInt(1, 2000);
-        final int length = OptimizedScalarQuantizer.discretize(dimensions, 64) / 8;
+        final int length = BQVectorUtils.discretize(dimensions, 64) / 8;
         final int numVectors = random().nextInt(1, 100);
         final byte[] vector = new byte[length];
         try (Directory dir = new MMapDirectory(createTempDir())) {
@@ -53,102 +57,208 @@ public class ES91OSQVectorScorerTests extends BaseVectorizationTests {
     }
 
     public void testScore() throws Exception {
-        final int maxDims = 512;
+        final int maxDims = random().nextInt(1, 1000) * 2;
         final int dimensions = random().nextInt(1, maxDims);
-        final int length = OptimizedScalarQuantizer.discretize(dimensions, 64) / 8;
-        final int numVectors = ES91OSQVectorsScorer.BULK_SIZE * random().nextInt(1, 10);
-        final byte[] vector = new byte[length];
+        final int length = BQVectorUtils.discretize(dimensions, 64) / 8;
+        final int numVectors = random().nextInt(10, 50);
+        float[][] vectors = new float[numVectors][dimensions];
+        final int[] scratch = new int[dimensions];
+        final byte[] qVector = new byte[length];
+        final float[] centroid = new float[dimensions];
+        VectorSimilarityFunction similarityFunction = randomFrom(VectorSimilarityFunction.values());
+        randomVector(centroid, similarityFunction);
+        OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(similarityFunction);
         int padding = random().nextInt(100);
         byte[] paddingBytes = new byte[padding];
         try (Directory dir = new MMapDirectory(createTempDir())) {
             try (IndexOutput out = dir.createOutput("testScore.bin", IOContext.DEFAULT)) {
                 random().nextBytes(paddingBytes);
                 out.writeBytes(paddingBytes, 0, padding);
+                for (float[] vector : vectors) {
+                    randomVector(vector, similarityFunction);
+                    OptimizedScalarQuantizer.QuantizationResult result = quantizer.scalarQuantize(
+                        vector.clone(),
+                        scratch,
+                        (byte) 1,
+                        centroid
+                    );
+                    BQVectorUtils.packAsBinary(scratch, qVector);
+                    out.writeBytes(qVector, 0, qVector.length);
+                    out.writeInt(Float.floatToIntBits(result.lowerInterval()));
+                    out.writeInt(Float.floatToIntBits(result.upperInterval()));
+                    out.writeInt(Float.floatToIntBits(result.additionalCorrection()));
+                    out.writeShort((short) result.quantizedComponentSum());
+                }
+            }
+            final float[] query = new float[dimensions];
+            randomVector(query, similarityFunction);
+            OptimizedScalarQuantizer.QuantizationResult queryCorrections = quantizer.scalarQuantize(
+                query.clone(),
+                scratch,
+                (byte) 4,
+                centroid
+            );
+            final byte[] quantizeQuery = new byte[4 * length];
+            BQSpaceUtils.transposeHalfByte(scratch, quantizeQuery);
+            final float centroidDp = VectorUtil.dotProduct(centroid, centroid);
+            final float[] floatScratch = new float[3];
+            try (IndexInput in = dir.openInput("testScore.bin", IOContext.DEFAULT)) {
+                in.seek(padding);
+                assertEquals(in.length(), padding + (long) numVectors * (length + 14));
+                final IndexInput slice = in.slice("test", in.getFilePointer(), (long) (length + 14) * numVectors);
+                // Work on a slice that has just the right number of bytes to make the test fail with an
+                // index-out-of-bounds in case the implementation reads more than the allowed number of
+                // padding bytes.
                 for (int i = 0; i < numVectors; i++) {
-                    random().nextBytes(vector);
-                    out.writeBytes(vector, 0, length);
-                    float lower = random().nextFloat();
-                    float upper = random().nextFloat() + lower / 2;
-                    float additionalCorrection = random().nextFloat();
-                    int targetComponentSum = randomIntBetween(0, dimensions / 2);
-                    out.writeInt(Float.floatToIntBits(lower));
-                    out.writeInt(Float.floatToIntBits(upper));
-                    out.writeShort((short) targetComponentSum);
-                    out.writeInt(Float.floatToIntBits(additionalCorrection));
+                    final ES91OSQVectorsScorer defaultScorer = defaultProvider().newES91OSQVectorsScorer(slice, dimensions);
+                    final ES91OSQVectorsScorer panamaScorer = maybePanamaProvider().newES91OSQVectorsScorer(in, dimensions);
+                    long qDist = defaultScorer.quantizeScore(quantizeQuery);
+                    slice.readFloats(floatScratch, 0, 3);
+                    int quantizedComponentSum = slice.readShort();
+                    float defaulScore = defaultScorer.score(
+                        queryCorrections.lowerInterval(),
+                        queryCorrections.upperInterval(),
+                        queryCorrections.quantizedComponentSum(),
+                        queryCorrections.additionalCorrection(),
+                        similarityFunction,
+                        centroidDp,
+                        floatScratch[0],
+                        floatScratch[1],
+                        quantizedComponentSum,
+                        floatScratch[2],
+                        qDist
+                    );
+                    qDist = panamaScorer.quantizeScore(quantizeQuery);
+                    in.readFloats(floatScratch, 0, 3);
+                    quantizedComponentSum = in.readShort();
+                    float panamaScore = panamaScorer.score(
+                        queryCorrections.lowerInterval(),
+                        queryCorrections.upperInterval(),
+                        queryCorrections.quantizedComponentSum(),
+                        queryCorrections.additionalCorrection(),
+                        similarityFunction,
+                        centroidDp,
+                        floatScratch[0],
+                        floatScratch[1],
+                        quantizedComponentSum,
+                        floatScratch[2],
+                        qDist
+                    );
+                    assertEquals(defaulScore, panamaScore, 1e-2f);
+                    assertEquals(((long) (i + 1) * (length + 14)), slice.getFilePointer());
+                    assertEquals(padding + ((long) (i + 1) * (length + 14)), in.getFilePointer());
                 }
             }
-            final byte[] query = new byte[4 * length];
-            random().nextBytes(query);
-            float lower = random().nextFloat();
-            OptimizedScalarQuantizer.QuantizationResult result = new OptimizedScalarQuantizer.QuantizationResult(
-                lower,
-                random().nextFloat() + lower / 2,
-                random().nextFloat(),
-                randomIntBetween(0, dimensions * 2)
+        }
+    }
+
+    public void testScoreBulk() throws Exception {
+        final int maxDims = random().nextInt(1, 1000) * 2;
+        final int dimensions = random().nextInt(1, maxDims);
+        final int length = BQVectorUtils.discretize(dimensions, 64) / 8;
+        final int numVectors = ES91OSQVectorsScorer.BULK_SIZE * random().nextInt(1, 10);
+        float[][] vectors = new float[numVectors][dimensions];
+        final int[] scratch = new int[dimensions];
+        final byte[] qVector = new byte[length];
+        final float[] centroid = new float[dimensions];
+        VectorSimilarityFunction similarityFunction = randomFrom(VectorSimilarityFunction.values());
+        randomVector(centroid, similarityFunction);
+        OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(similarityFunction);
+        int padding = random().nextInt(100);
+        byte[] paddingBytes = new byte[padding];
+        try (Directory dir = new MMapDirectory(createTempDir())) {
+            try (IndexOutput out = dir.createOutput("testScore.bin", IOContext.DEFAULT)) {
+                random().nextBytes(paddingBytes);
+                out.writeBytes(paddingBytes, 0, padding);
+                int limit = numVectors - ES91OSQVectorsScorer.BULK_SIZE + 1;
+                OptimizedScalarQuantizer.QuantizationResult[] results =
+                    new OptimizedScalarQuantizer.QuantizationResult[ES91Int4VectorsScorer.BULK_SIZE];
+                for (int i = 0; i < limit; i += ES91OSQVectorsScorer.BULK_SIZE) {
+                    for (int j = 0; j < ES91Int4VectorsScorer.BULK_SIZE; j++) {
+                        randomVector(vectors[i + j], similarityFunction);
+                        results[j] = quantizer.scalarQuantize(vectors[i + j].clone(), scratch, (byte) 1, centroid);
+                        BQVectorUtils.packAsBinary(scratch, qVector);
+                        out.writeBytes(qVector, 0, qVector.length);
+                    }
+                    writeCorrections(results, out);
+                }
+            }
+            final float[] query = new float[dimensions];
+            randomVector(query, similarityFunction);
+            OptimizedScalarQuantizer.QuantizationResult queryCorrections = quantizer.scalarQuantize(
+                query.clone(),
+                scratch,
+                (byte) 4,
+                centroid
             );
-            final float centroidDp = random().nextFloat();
-            final float[] scores1 = new float[ES91OSQVectorsScorer.BULK_SIZE];
-            final float[] scores2 = new float[ES91OSQVectorsScorer.BULK_SIZE];
-            for (VectorSimilarityFunction similarityFunction : VectorSimilarityFunction.values()) {
-                try (IndexInput in = dir.openInput("testScore.bin", IOContext.DEFAULT)) {
-                    in.seek(padding);
-                    assertEquals(in.length(), padding + (long) numVectors * (length + 14));
-                    // Work on a slice that has just the right number of bytes to make the test fail with an
-                    // index-out-of-bounds in case the implementation reads more than the allowed number of
-                    // padding bytes.
-                    for (int i = 0; i < numVectors; i += ES91OSQVectorsScorer.BULK_SIZE) {
-                        final IndexInput slice = in.slice(
-                            "test",
-                            in.getFilePointer(),
-                            (long) (length + 14) * ES91OSQVectorsScorer.BULK_SIZE
-                        );
-                        final ES91OSQVectorsScorer defaultScorer = defaultProvider().newES91OSQVectorsScorer(slice, dimensions);
-                        final ES91OSQVectorsScorer panamaScorer = maybePanamaProvider().newES91OSQVectorsScorer(in, dimensions);
-                        defaultScorer.scoreBulk(
-                            query,
-                            result.lowerInterval(),
-                            result.upperInterval(),
-                            result.quantizedComponentSum(),
-                            result.additionalCorrection(),
-                            similarityFunction,
-                            centroidDp,
-                            scores1
-                        );
-                        panamaScorer.scoreBulk(
-                            query,
-                            result.lowerInterval(),
-                            result.upperInterval(),
-                            result.quantizedComponentSum(),
-                            result.additionalCorrection(),
-                            similarityFunction,
-                            centroidDp,
-                            scores2
-                        );
-                        for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
-                            if (scores1[j] == scores2[j]) {
-                                continue;
-                            }
-                            if (scores1[j] > (maxDims * Byte.MAX_VALUE)) {
-                                float diff = Math.abs(scores1[j] - scores2[j]);
-                                assertThat(
-                                    "defaultScores: " + scores1[j] + " bulkScores: " + scores2[j],
-                                    diff / scores1[j],
-                                    lessThan(1e-5f)
-                                );
-                                assertThat(
-                                    "defaultScores: " + scores1[j] + " bulkScores: " + scores2[j],
-                                    diff / scores2[j],
-                                    lessThan(1e-5f)
-                                );
-                            } else {
-                                assertEquals(scores1[j], scores2[j], 1e-2f);
-                            }
-                        }
-                        assertEquals(((long) (ES91OSQVectorsScorer.BULK_SIZE) * (length + 14)), slice.getFilePointer());
-                        assertEquals(padding + ((long) (i + ES91OSQVectorsScorer.BULK_SIZE) * (length + 14)), in.getFilePointer());
+            final byte[] quantizeQuery = new byte[4 * length];
+            BQSpaceUtils.transposeHalfByte(scratch, quantizeQuery);
+            final float centroidDp = VectorUtil.dotProduct(centroid, centroid);
+            final float[] scoresDefault = new float[ES91OSQVectorsScorer.BULK_SIZE];
+            final float[] scoresPanama = new float[ES91OSQVectorsScorer.BULK_SIZE];
+            try (IndexInput in = dir.openInput("testScore.bin", IOContext.DEFAULT)) {
+                in.seek(padding);
+                assertEquals(in.length(), padding + (long) numVectors * (length + 14));
+                // Work on a slice that has just the right number of bytes to make the test fail with an
+                // index-out-of-bounds in case the implementation reads more than the allowed number of
+                // padding bytes.
+                for (int i = 0; i < numVectors; i += ES91OSQVectorsScorer.BULK_SIZE) {
+                    final IndexInput slice = in.slice("test", in.getFilePointer(), (long) (length + 14) * ES91OSQVectorsScorer.BULK_SIZE);
+                    final ES91OSQVectorsScorer defaultScorer = defaultProvider().newES91OSQVectorsScorer(slice, dimensions);
+                    final ES91OSQVectorsScorer panamaScorer = maybePanamaProvider().newES91OSQVectorsScorer(in, dimensions);
+                    float defaultMaxScore = defaultScorer.scoreBulk(
+                        quantizeQuery,
+                        queryCorrections.lowerInterval(),
+                        queryCorrections.upperInterval(),
+                        queryCorrections.quantizedComponentSum(),
+                        queryCorrections.additionalCorrection(),
+                        similarityFunction,
+                        centroidDp,
+                        scoresDefault
+                    );
+                    float panamaMaxScore = panamaScorer.scoreBulk(
+                        quantizeQuery,
+                        queryCorrections.lowerInterval(),
+                        queryCorrections.upperInterval(),
+                        queryCorrections.quantizedComponentSum(),
+                        queryCorrections.additionalCorrection(),
+                        similarityFunction,
+                        centroidDp,
+                        scoresPanama
+                    );
+                    assertEquals(defaultMaxScore, panamaMaxScore, 1e-2f);
+                    for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
+                        assertEquals(scoresDefault[j], scoresPanama[j], 1e-2f);
                     }
+                    assertEquals(((long) (ES91OSQVectorsScorer.BULK_SIZE) * (length + 14)), slice.getFilePointer());
+                    assertEquals(padding + ((long) (i + ES91OSQVectorsScorer.BULK_SIZE) * (length + 14)), in.getFilePointer());
                 }
             }
         }
     }
+
+    private static void writeCorrections(OptimizedScalarQuantizer.QuantizationResult[] corrections, IndexOutput out) throws IOException {
+        for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
+            out.writeInt(Float.floatToIntBits(correction.lowerInterval()));
+        }
+        for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
+            out.writeInt(Float.floatToIntBits(correction.upperInterval()));
+        }
+        for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
+            int targetComponentSum = correction.quantizedComponentSum();
+            out.writeShort((short) targetComponentSum);
+        }
+        for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
+            out.writeInt(Float.floatToIntBits(correction.additionalCorrection()));
+        }
+    }
+
+    private void randomVector(float[] vector, VectorSimilarityFunction vectorSimilarityFunction) {
+        for (int i = 0; i < vector.length; i++) {
+            vector[i] = random().nextFloat();
+        }
+        if (vectorSimilarityFunction != VectorSimilarityFunction.EUCLIDEAN) {
+            VectorUtil.l2normalize(vector);
+        }
+    }
 }