Browse Source

Speed up OptimizedScalarQuantizer (#131599)

Use the destination array to keep the quantized values during the loss computation and pass them to the
method computing the grid points, so the vector is not re-quantized there.
Ignacio Vera 2 months ago
parent
commit
4468239dee

+ 3 - 12
benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OptimizedScalarQuantizerBenchmark.java

@@ -43,7 +43,6 @@ public class OptimizedScalarQuantizerBenchmark {
 
     float[] vector;
     float[] centroid;
-    byte[] legacyDestination;
     int[] destination;
 
     @Param({ "1", "4", "7" })
@@ -55,7 +54,6 @@ public class OptimizedScalarQuantizerBenchmark {
     public void init() {
         ThreadLocalRandom random = ThreadLocalRandom.current();
         // random byte arrays for binary methods
-        legacyDestination = new byte[dims];
         destination = new int[dims];
         vector = new float[dims];
         centroid = new float[dims];
@@ -66,16 +64,9 @@ public class OptimizedScalarQuantizerBenchmark {
     }
 
     @Benchmark
-    public byte[] scalar() {
-        osq.legacyScalarQuantize(vector, legacyDestination, bits, centroid);
-        return legacyDestination;
-    }
-
-    @Benchmark
-    @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
-    public byte[] legacyVector() {
-        osq.legacyScalarQuantize(vector, legacyDestination, bits, centroid);
-        return legacyDestination;
+    public int[] scalar() {
+        osq.scalarQuantize(vector, destination, bits, centroid);
+        return destination;
     }
 
     @Benchmark

+ 5 - 0
docs/changelog/131599.yaml

@@ -0,0 +1,5 @@
+pr: 131599
+summary: Speed up `OptimizedScalarQuantizer`
+area: Vector Search
+type: enhancement
+issues: []

+ 20 - 10
libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java

@@ -158,31 +158,41 @@ public class ESVectorUtil {
     /**
      * Calculate the loss for optimized-scalar quantization for the given parameteres
      * @param target The vector being quantized, assumed to be centered
-     * @param interval The interval for which to calculate the loss
+     * @param lowerInterval The lower interval value for which to calculate the loss
+     * @param upperInterval The upper interval value for which to calculate the loss
      * @param points the quantization points
      * @param norm2 The norm squared of the target vector
      * @param lambda The lambda parameter for controlling anisotropic loss calculation
+     * @param quantize array to store the computed quantized vector.
+     *
      * @return The loss for the given parameters
      */
-    public static float calculateOSQLoss(float[] target, float[] interval, int points, float norm2, float lambda) {
-        assert interval.length == 2;
-        float step = ((interval[1] - interval[0]) / (points - 1.0F));
+    public static float calculateOSQLoss(
+        float[] target,
+        float lowerInterval,
+        float upperInterval,
+        int points,
+        float norm2,
+        float lambda,
+        int[] quantize
+    ) {
+        assert upperInterval >= lowerInterval;
+        float step = ((upperInterval - lowerInterval) / (points - 1.0F));
         float invStep = 1f / step;
-        return IMPL.calculateOSQLoss(target, interval, step, invStep, norm2, lambda);
+        return IMPL.calculateOSQLoss(target, lowerInterval, upperInterval, step, invStep, norm2, lambda, quantize);
     }
 
     /**
      * Calculate the grid points for optimized-scalar quantization
      * @param target The vector being quantized, assumed to be centered
-     * @param interval The interval for which to calculate the grid points
+     * @param quantize The quantized vector, which must be at least as long as the target vector
      * @param points the quantization points
      * @param pts The array to store the grid points, must be of length 5
      */
-    public static void calculateOSQGridPoints(float[] target, float[] interval, int points, float[] pts) {
-        assert interval.length == 2;
+    public static void calculateOSQGridPoints(float[] target, int[] quantize, int points, float[] pts) {
+        assert target.length <= quantize.length;
         assert pts.length == 5;
-        float invStep = (points - 1.0F) / (interval[1] - interval[0]);
-        IMPL.calculateOSQGridPoints(target, interval, points, invStep, pts);
+        IMPL.calculateOSQGridPoints(target, quantize, points, pts);
     }
 
     /**

+ 22 - 12
libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java

@@ -46,14 +46,25 @@ final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
     }
 
     @Override
-    public float calculateOSQLoss(float[] target, float[] interval, float step, float invStep, float norm2, float lambda) {
-        float a = interval[0];
-        float b = interval[1];
+    public float calculateOSQLoss(
+        float[] target,
+        float low,
+        float high,
+        float step,
+        float invStep,
+        float norm2,
+        float lambda,
+        int[] quantize
+    ) {
+        float a = low;
+        float b = high;
         float xe = 0f;
         float e = 0f;
-        for (float xi : target) {
+        for (int i = 0; i < target.length; ++i) {
+            float xi = target[i];
             // this is quantizing and then dequantizing the vector
-            float xiq = fma(step, Math.round((Math.min(Math.max(xi, a), b) - a) * invStep), a);
+            quantize[i] = Math.round((Math.min(Math.max(xi, a), b) - a) * invStep);
+            float xiq = fma(step, quantize[i], a);
             // how much does the de-quantized value differ from the original value
             float xiiq = xi - xiq;
             e = fma(xiiq, xiiq, e);
@@ -63,16 +74,15 @@ final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
     }
 
     @Override
-    public void calculateOSQGridPoints(float[] target, float[] interval, int points, float invStep, float[] pts) {
-        float a = interval[0];
-        float b = interval[1];
+    public void calculateOSQGridPoints(float[] target, int[] quantize, int points, float[] pts) {
         float daa = 0;
         float dab = 0;
         float dbb = 0;
         float dax = 0;
         float dbx = 0;
-        for (float v : target) {
-            float k = Math.round((Math.min(Math.max(v, a), b) - a) * invStep);
+        for (int i = 0; i < target.length; ++i) {
+            float v = target[i];
+            float k = quantize[i];
             float s = k / (points - 1);
             float ms = 1f - s;
             daa = fma(ms, ms, daa);
@@ -273,11 +283,11 @@ final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
     @Override
     public int quantizeVectorWithIntervals(float[] vector, int[] destination, float lowInterval, float upperInterval, byte bits) {
         float nSteps = ((1 << bits) - 1);
-        float step = (upperInterval - lowInterval) / nSteps;
+        float invStep = nSteps / (upperInterval - lowInterval);
         int sumQuery = 0;
         for (int h = 0; h < vector.length; h++) {
             float xi = Math.min(Math.max(vector[h], lowInterval), upperInterval);
-            int assignment = Math.round((xi - lowInterval) / step);
+            int assignment = Math.round((xi - lowInterval) * invStep);
             sumQuery += assignment;
             destination[h] = assignment;
         }

+ 12 - 3
libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java

@@ -29,9 +29,18 @@ public interface ESVectorUtilSupport {
 
     float ipFloatByte(float[] q, byte[] d);
 
-    float calculateOSQLoss(float[] target, float[] interval, float step, float invStep, float norm2, float lambda);
-
-    void calculateOSQGridPoints(float[] target, float[] interval, int points, float invStep, float[] pts);
+    float calculateOSQLoss(
+        float[] target,
+        float lowerInterval,
+        float upperInterval,
+        float step,
+        float invStep,
+        float norm2,
+        float lambda,
+        int[] quantize
+    );
+
+    void calculateOSQGridPoints(float[] target, int[] quantize, int points, float[] pts);
 
     void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats);
 

+ 29 - 26
libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java

@@ -13,7 +13,6 @@ import jdk.incubator.vector.ByteVector;
 import jdk.incubator.vector.FloatVector;
 import jdk.incubator.vector.IntVector;
 import jdk.incubator.vector.LongVector;
-import jdk.incubator.vector.Vector;
 import jdk.incubator.vector.VectorMask;
 import jdk.incubator.vector.VectorOperators;
 import jdk.incubator.vector.VectorShape;
@@ -31,6 +30,7 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
     static final int VECTOR_BITSIZE;
 
     private static final VectorSpecies<Float> FLOAT_SPECIES;
+    private static final VectorSpecies<Integer> INTEGER_SPECIES;
     /** Whether integer vectors can be trusted to actually be fast. */
     static final boolean HAS_FAST_INTEGER_VECTORS;
 
@@ -38,6 +38,7 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
         // default to platform supported bitsize
         VECTOR_BITSIZE = VectorShape.preferredShape().vectorBitSize();
         FLOAT_SPECIES = VectorSpecies.of(float.class, VectorShape.forBitSize(VECTOR_BITSIZE));
+        INTEGER_SPECIES = VectorSpecies.of(int.class, VectorShape.forBitSize(VECTOR_BITSIZE));
 
         // hotspot misses some SSE intrinsics, workaround it
         // to be fair, they do document this thing only works well with AVX2/AVX3 and Neon
@@ -270,36 +271,26 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
     }
 
     @Override
-    public void calculateOSQGridPoints(float[] target, float[] interval, int points, float invStep, float[] pts) {
-        float a = interval[0];
-        float b = interval[1];
+    public void calculateOSQGridPoints(float[] target, int[] quantize, int points, float[] pts) {
         int i = 0;
         float daa = 0;
         float dab = 0;
         float dbb = 0;
         float dax = 0;
         float dbx = 0;
-
-        FloatVector daaVec = FloatVector.zero(FLOAT_SPECIES);
-        FloatVector dabVec = FloatVector.zero(FLOAT_SPECIES);
-        FloatVector dbbVec = FloatVector.zero(FLOAT_SPECIES);
-        FloatVector daxVec = FloatVector.zero(FLOAT_SPECIES);
-        FloatVector dbxVec = FloatVector.zero(FLOAT_SPECIES);
-
         // if the array size is large (> 2x platform vector size), it's worth the overhead to vectorize
         if (target.length > 2 * FLOAT_SPECIES.length()) {
+            FloatVector daaVec = FloatVector.zero(FLOAT_SPECIES);
+            FloatVector dabVec = FloatVector.zero(FLOAT_SPECIES);
+            FloatVector dbbVec = FloatVector.zero(FLOAT_SPECIES);
+            FloatVector daxVec = FloatVector.zero(FLOAT_SPECIES);
+            FloatVector dbxVec = FloatVector.zero(FLOAT_SPECIES);
             FloatVector ones = FloatVector.broadcast(FLOAT_SPECIES, 1f);
             FloatVector pmOnes = FloatVector.broadcast(FLOAT_SPECIES, points - 1f);
             for (; i < FLOAT_SPECIES.loopBound(target.length); i += FLOAT_SPECIES.length()) {
                 FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, target, i);
-                FloatVector vClamped = v.max(a).min(b);
-                Vector<Integer> xiqint = vClamped.sub(a)
-                    .mul(invStep)
-                    // round
-                    .add(0.5f)
-                    .convert(VectorOperators.F2I, 0);
-                FloatVector kVec = xiqint.convert(VectorOperators.I2F, 0).reinterpretAsFloats();
-                FloatVector sVec = kVec.div(pmOnes);
+                FloatVector oVec = IntVector.fromArray(INTEGER_SPECIES, quantize, i).convert(VectorOperators.I2F, 0).reinterpretAsFloats();
+                FloatVector sVec = oVec.div(pmOnes);
                 FloatVector smVec = ones.sub(sVec);
                 daaVec = fma(smVec, smVec, daaVec);
                 dabVec = fma(smVec, sVec, dabVec);
@@ -315,7 +306,7 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
         }
 
         for (; i < target.length; i++) {
-            float k = Math.round((Math.min(Math.max(target[i], a), b) - a) * invStep);
+            float k = quantize[i];
             float s = k / (points - 1);
             float ms = 1f - s;
             daa = fma(ms, ms, daa);
@@ -333,9 +324,18 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
     }
 
     @Override
-    public float calculateOSQLoss(float[] target, float[] interval, float step, float invStep, float norm2, float lambda) {
-        float a = interval[0];
-        float b = interval[1];
+    public float calculateOSQLoss(
+        float[] target,
+        float lowerInterval,
+        float upperInterval,
+        float step,
+        float invStep,
+        float norm2,
+        float lambda,
+        int[] quantize
+    ) {
+        float a = lowerInterval;
+        float b = upperInterval;
         float xe = 0f;
         float e = 0f;
         FloatVector xeVec = FloatVector.zero(FLOAT_SPECIES);
@@ -346,8 +346,10 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
             for (; i < FLOAT_SPECIES.loopBound(target.length); i += FLOAT_SPECIES.length()) {
                 FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, target, i);
                 FloatVector vClamped = v.max(a).min(b);
-                Vector<Integer> xiqint = vClamped.sub(a).mul(invStep).add(0.5f).convert(VectorOperators.F2I, 0);
-                FloatVector xiq = xiqint.convert(VectorOperators.I2F, 0).reinterpretAsFloats().mul(step).add(a);
+                IntVector xiqint = vClamped.sub(a).mul(invStep).add(0.5f).convert(VectorOperators.F2I, 0).reinterpretAsInts();
+                xiqint.intoArray(quantize, i);
+                FloatVector quantizeVec = xiqint.convert(VectorOperators.I2F, 0).reinterpretAsFloats();
+                FloatVector xiq = quantizeVec.mul(step).add(a);
                 FloatVector xiiq = v.sub(xiq);
                 xeVec = fma(v, xiiq, xeVec);
                 eVec = fma(xiiq, xiiq, eVec);
@@ -357,8 +359,9 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
         }
 
         for (; i < target.length; i++) {
+            quantize[i] = Math.round((Math.min(Math.max(target[i], a), b) - a) * invStep);
             // this is quantizing and then dequantizing the vector
-            float xiq = fma(step, Math.round((Math.min(Math.max(target[i], a), b) - a) * invStep), a);
+            float xiq = fma(step, quantize[i], a);
             // how much does the de-quantized value differ from the original value
             float xiiq = target[i] - xiq;
             e = fma(xiiq, xiiq, e);

+ 18 - 4
libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java

@@ -222,15 +222,20 @@ public class ESVectorUtilTests extends BaseVectorizationTests {
         vecVar /= size;
         float vecStd = (float) Math.sqrt(vecVar);
 
+        int[] destinationDefault = new int[size];
+        int[] destinationPanama = new int[size];
         for (byte bits : new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 }) {
             int points = 1 << bits;
             float[] initInterval = new float[2];
             OptimizedScalarQuantizer.initInterval(bits, vecStd, vecMean, min, max, initInterval);
             float step = ((initInterval[1] - initInterval[0]) / (points - 1f));
             float stepInv = 1f / step;
-            float expected = defaultedProvider.getVectorUtilSupport().calculateOSQLoss(vector, initInterval, step, stepInv, norm2, 0.1f);
-            float result = defOrPanamaProvider.getVectorUtilSupport().calculateOSQLoss(vector, initInterval, step, stepInv, norm2, 0.1f);
+            float expected = defaultedProvider.getVectorUtilSupport()
+                .calculateOSQLoss(vector, initInterval[0], initInterval[1], step, stepInv, norm2, 0.1f, destinationDefault);
+            float result = defOrPanamaProvider.getVectorUtilSupport()
+                .calculateOSQLoss(vector, initInterval[0], initInterval[1], step, stepInv, norm2, 0.1f, destinationPanama);
             assertEquals(expected, result, deltaEps);
+            assertArrayEquals(destinationDefault, destinationPanama);
         }
     }
 
@@ -240,6 +245,7 @@ public class ESVectorUtilTests extends BaseVectorizationTests {
         var vector = new float[size];
         var min = Float.MAX_VALUE;
         var max = -Float.MAX_VALUE;
+        var norm2 = 0f;
         float vecMean = 0;
         float vecVar = 0;
         for (int i = 0; i < size; ++i) {
@@ -250,9 +256,12 @@ public class ESVectorUtilTests extends BaseVectorizationTests {
             vecMean += delta / (i + 1);
             float delta2 = vector[i] - vecMean;
             vecVar += delta * delta2;
+            norm2 += vector[i] * vector[i];
         }
         vecVar /= size;
         float vecStd = (float) Math.sqrt(vecVar);
+        int[] destinationDefault = new int[size];
+        int[] destinationPanama = new int[size];
         for (byte bits : new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 }) {
             int points = 1 << bits;
             float[] initInterval = new float[2];
@@ -260,11 +269,16 @@ public class ESVectorUtilTests extends BaseVectorizationTests {
             float step = ((initInterval[1] - initInterval[0]) / (points - 1f));
             float stepInv = 1f / step;
             float[] expected = new float[5];
-            defaultedProvider.getVectorUtilSupport().calculateOSQGridPoints(vector, initInterval, points, stepInv, expected);
+            defaultedProvider.getVectorUtilSupport()
+                .calculateOSQLoss(vector, initInterval[0], initInterval[1], step, stepInv, norm2, 0.1f, destinationDefault);
+            defaultedProvider.getVectorUtilSupport().calculateOSQGridPoints(vector, destinationDefault, points, expected);
 
             float[] result = new float[5];
-            defOrPanamaProvider.getVectorUtilSupport().calculateOSQGridPoints(vector, initInterval, points, stepInv, result);
+            defOrPanamaProvider.getVectorUtilSupport()
+                .calculateOSQLoss(vector, initInterval[0], initInterval[1], step, stepInv, norm2, 0.1f, destinationPanama);
+            defOrPanamaProvider.getVectorUtilSupport().calculateOSQGridPoints(vector, destinationPanama, points, result);
             assertArrayEquals(expected, result, deltaEps);
+            assertArrayEquals(destinationDefault, destinationPanama);
         }
     }
 

+ 38 - 58
server/src/main/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizer.java

@@ -78,15 +78,20 @@ public class OptimizedScalarQuantizer {
             int points = (1 << bits[i]);
             // Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds
             initInterval(bits[i], vecStd, vecMean, min, max, intervalScratch);
-            optimizeIntervals(intervalScratch, vector, norm2, points);
+            boolean hasQuantization = optimizeIntervals(intervalScratch, destinations[i], vector, norm2, points);
             // Now we have the optimized intervals, quantize the vector
-            int sumQuery = ESVectorUtil.quantizeVectorWithIntervals(
-                vector,
-                destinations[i],
-                intervalScratch[0],
-                intervalScratch[1],
-                bits[i]
-            );
+            int sumQuery;
+            if (hasQuantization) {
+                sumQuery = getSumQuery(destinations[i]);
+            } else {
+                sumQuery = ESVectorUtil.quantizeVectorWithIntervals(
+                    vector,
+                    destinations[i],
+                    intervalScratch[0],
+                    intervalScratch[1],
+                    bits[i]
+                );
+            }
             results[i] = new QuantizationResult(
                 intervalScratch[0],
                 intervalScratch[1],
@@ -97,8 +102,7 @@ public class OptimizedScalarQuantizer {
         return results;
     }
 
-    // This method is only used for benchmarking purposes, it is not used in production
-    public QuantizationResult legacyScalarQuantize(float[] vector, byte[] destination, byte bits, float[] centroid) {
+    public QuantizationResult scalarQuantize(float[] vector, int[] destination, byte bits, float[] centroid) {
         assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector);
         assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid);
         assert vector.length <= destination.length;
@@ -117,49 +121,14 @@ public class OptimizedScalarQuantizer {
         float vecStd = (float) Math.sqrt(vecVar);
         // Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds
         initInterval(bits, vecStd, vecMean, min, max, intervalScratch);
-        optimizeIntervals(intervalScratch, vector, norm2, points);
-        float nSteps = ((1 << bits) - 1);
+        boolean hasQuantization = optimizeIntervals(intervalScratch, destination, vector, norm2, points);
         // Now we have the optimized intervals, quantize the vector
-        float a = intervalScratch[0];
-        float b = intervalScratch[1];
-        float step = (b - a) / nSteps;
-        int sumQuery = 0;
-        for (int h = 0; h < vector.length; h++) {
-            float xi = (float) clamp(vector[h], a, b);
-            int assignment = Math.round((xi - a) / step);
-            sumQuery += assignment;
-            destination[h] = (byte) assignment;
-        }
-        return new QuantizationResult(
-            intervalScratch[0],
-            intervalScratch[1],
-            similarityFunction == EUCLIDEAN ? norm2 : statsScratch[5],
-            sumQuery
-        );
-    }
-
-    public QuantizationResult scalarQuantize(float[] vector, int[] destination, byte bits, float[] centroid) {
-        assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector);
-        assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid);
-        assert vector.length <= destination.length;
-        assert bits > 0 && bits <= 8;
-        int points = 1 << bits;
-        if (similarityFunction == EUCLIDEAN) {
-            ESVectorUtil.centerAndCalculateOSQStatsEuclidean(vector, centroid, vector, statsScratch);
+        int sumQuery;
+        if (hasQuantization) {
+            sumQuery = getSumQuery(destination);
         } else {
-            ESVectorUtil.centerAndCalculateOSQStatsDp(vector, centroid, vector, statsScratch);
+            sumQuery = ESVectorUtil.quantizeVectorWithIntervals(vector, destination, intervalScratch[0], intervalScratch[1], bits);
         }
-        float vecMean = statsScratch[0];
-        float vecVar = statsScratch[1];
-        float norm2 = statsScratch[2];
-        float min = statsScratch[3];
-        float max = statsScratch[4];
-        float vecStd = (float) Math.sqrt(vecVar);
-        // Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds
-        initInterval(bits, vecStd, vecMean, min, max, intervalScratch);
-        optimizeIntervals(intervalScratch, vector, norm2, points);
-        // Now we have the optimized intervals, quantize the vector
-        int sumQuery = ESVectorUtil.quantizeVectorWithIntervals(vector, destination, intervalScratch[0], intervalScratch[1], bits);
         return new QuantizationResult(
             intervalScratch[0],
             intervalScratch[1],
@@ -176,16 +145,18 @@ public class OptimizedScalarQuantizer {
      * @param vector raw vector
      * @param norm2 squared norm of the vector
      * @param points number of quantization points
+     *
+     * @return true if {@code destination} contains the quantized vector and we can skip the quantization.
      */
-    private void optimizeIntervals(float[] initInterval, float[] vector, float norm2, int points) {
-        double initialLoss = ESVectorUtil.calculateOSQLoss(vector, initInterval, points, norm2, lambda);
+    private boolean optimizeIntervals(float[] initInterval, int[] destination, float[] vector, float norm2, int points) {
+        double initialLoss = ESVectorUtil.calculateOSQLoss(vector, initInterval[0], initInterval[1], points, norm2, lambda, destination);
         final float scale = (1.0f - lambda) / norm2;
         if (Float.isFinite(scale) == false) {
-            return;
+            return true;
         }
         for (int i = 0; i < iters; ++i) {
             // calculate the grid points for coordinate descent
-            ESVectorUtil.calculateOSQGridPoints(vector, initInterval, points, gridScratch);
+            ESVectorUtil.calculateOSQGridPoints(vector, destination, points, gridScratch);
             float daa = gridScratch[0];
             float dab = gridScratch[1];
             float dbb = gridScratch[2];
@@ -197,26 +168,35 @@ public class OptimizedScalarQuantizer {
             // its possible that the determinant is 0, in which case we can't update the interval
             double det = m0 * m2 - m1 * m1;
             if (det == 0) {
-                return;
+                return true;
             }
             float aOpt = (float) ((m2 * dax - m1 * dbx) / det);
             float bOpt = (float) ((m0 * dbx - m1 * dax) / det);
             // If there is no change in the interval, we can stop
             if ((Math.abs(initInterval[0] - aOpt) < 1e-8 && Math.abs(initInterval[1] - bOpt) < 1e-8)) {
-                return;
+                return true;
             }
-            double newLoss = ESVectorUtil.calculateOSQLoss(vector, new float[] { aOpt, bOpt }, points, norm2, lambda);
+            double newLoss = ESVectorUtil.calculateOSQLoss(vector, aOpt, bOpt, points, norm2, lambda, destination);
             // If the new loss is worse, don't update the interval and exit
             // This optimization, unlike kMeans, does not always converge to better loss
             // So exit if we are getting worse
             if (newLoss > initialLoss) {
-                return;
+                return false;
             }
             // Update the interval and go again
             initInterval[0] = aOpt;
             initInterval[1] = bOpt;
             initialLoss = newLoss;
         }
+        return true;
+    }
+
+    private static int getSumQuery(int[] quantize) {
+        int sum = 0;
+        for (int q : quantize) {
+            sum += q;
+        }
+        return sum;
     }
 
     private static double clamp(double x, double a, double b) {

+ 0 - 14
server/src/test/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizerTests.java

@@ -52,7 +52,6 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
         float[] scratch = new float[dims];
         for (byte bit : ALL_BITS) {
             float eps = (1f / (float) (1 << (bit)));
-            byte[] legacyDestination = new byte[dims];
             int[] destination = new int[dims];
             for (int i = 0; i < numVectors; ++i) {
                 System.arraycopy(vectors[i], 0, scratch, 0, dims);
@@ -73,19 +72,6 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
                 }
                 mae /= dims;
                 assertTrue("bits: " + bit + " mae: " + mae + " > eps: " + eps, mae <= eps);
-
-                // check we get the same result from the int version
-                System.arraycopy(vectors[i], 0, scratch, 0, dims);
-                OptimizedScalarQuantizer.QuantizationResult intResults = osq.legacyScalarQuantize(
-                    scratch,
-                    legacyDestination,
-                    bit,
-                    centroid
-                );
-                assertEquals(result, intResults);
-                for (int h = 0; h < dims; ++h) {
-                    assertEquals((byte) destination[h], legacyDestination[h]);
-                }
             }
         }
     }