Browse Source

optimize OptimizedScalarQuantizer#scalarQuantize (#129874)

Optimize OptimizedScalarQuantizer#scalarQuantize when the destination can be an integer array.
Ignacio Vera 3 months ago
parent
commit
f81d35536d
17 changed files with 227 additions and 51 deletions
  1. 14 5
      benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OptimizedScalarQuantizerBenchmark.java
  2. 21 0
      libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java
  3. 14 0
      libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java
  4. 2 0
      libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java
  5. 28 0
      libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java
  6. 23 0
      libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java
  7. 29 0
      server/src/main/java/org/elasticsearch/index/codec/vectors/BQSpaceUtils.java
  8. 1 1
      server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java
  9. 8 4
      server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java
  10. 6 2
      server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java
  11. 2 2
      server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java
  12. 40 13
      server/src/main/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizer.java
  13. 1 1
      server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java
  14. 5 5
      server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java
  15. 4 4
      server/src/test/java/org/elasticsearch/index/codec/vectors/BQVectorUtilsTests.java
  16. 28 13
      server/src/test/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizerTests.java
  17. 1 1
      server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormatTests.java

+ 14 - 5
benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OptimizedScalarQuantizerBenchmark.java

@@ -43,7 +43,8 @@ public class OptimizedScalarQuantizerBenchmark {
 
     float[] vector;
     float[] centroid;
-    byte[] destination;
+    byte[] legacyDestination;
+    int[] destination;
 
     @Param({ "1", "4", "7" })
     byte bits;
@@ -54,7 +55,8 @@ public class OptimizedScalarQuantizerBenchmark {
     public void init() {
         ThreadLocalRandom random = ThreadLocalRandom.current();
         // random byte arrays for binary methods
-        destination = new byte[dims];
+        legacyDestination = new byte[dims];
+        destination = new int[dims];
         vector = new float[dims];
         centroid = new float[dims];
         for (int i = 0; i < dims; ++i) {
@@ -65,13 +67,20 @@ public class OptimizedScalarQuantizerBenchmark {
 
     @Benchmark
     public byte[] scalar() {
-        osq.scalarQuantize(vector, destination, bits, centroid);
-        return destination;
+        osq.legacyScalarQuantize(vector, legacyDestination, bits, centroid);
+        return legacyDestination;
+    }
+
+    @Benchmark
+    @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
+    public byte[] legacyVector() {
+        osq.legacyScalarQuantize(vector, legacyDestination, bits, centroid);
+        return legacyDestination;
     }
 
     @Benchmark
     @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
-    public byte[] vector() {
+    public int[] vector() {
         osq.scalarQuantize(vector, destination, bits, centroid);
         return destination;
     }

+ 21 - 0
libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java

@@ -258,4 +258,25 @@ public class ESVectorUtil {
         }
         return IMPL.soarDistance(v1, centroid, originalResidual, soarLambda, rnorm);
     }
+
+    /**
+     * Optimized-scalar quantization of the provided vector to the provided destination array.
+     *
+     * @param vector the vector to quantize
+     * @param destination the array to store the result
+     * @param lowInterval the minimum value, lower values in the original array will be replaced by this value
+     * @param upperInterval the maximum value, bigger values in the original array will be replaced by this value
+     * @param bit the number of bits to use for quantization, must be between 1 and 8
+     *
+     * @return the sum of all the elements of the resulting quantized vector.
+     */
+    public static int quantizeVectorWithIntervals(float[] vector, int[] destination, float lowInterval, float upperInterval, byte bit) {
+        if (vector.length > destination.length) {
+            throw new IllegalArgumentException("vector dimensions differ: " + vector.length + "!=" + destination.length);
+        }
+        if (bit <= 0 || bit > Byte.SIZE) {
+            throw new IllegalArgumentException("bit must be between 1 and 8, but was: " + bit);
+        }
+        return IMPL.quantizeVectorWithIntervals(vector, destination, lowInterval, upperInterval, bit);
+    }
 }

+ 14 - 0
libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java

@@ -269,4 +269,18 @@ final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
         }
         return ret;
     }
+
+    @Override
+    public int quantizeVectorWithIntervals(float[] vector, int[] destination, float lowInterval, float upperInterval, byte bits) {
+        float nSteps = ((1 << bits) - 1);
+        float step = (upperInterval - lowInterval) / nSteps;
+        int sumQuery = 0;
+        for (int h = 0; h < vector.length; h++) {
+            float xi = Math.min(Math.max(vector[h], lowInterval), upperInterval);
+            int assignment = Math.round((xi - lowInterval) / step);
+            sumQuery += assignment;
+            destination[h] = assignment;
+        }
+        return sumQuery;
+    }
 }

+ 2 - 0
libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java

@@ -39,4 +39,6 @@ public interface ESVectorUtilSupport {
 
     float soarDistance(float[] v1, float[] centroid, float[] originalResidual, float soarLambda, float rnorm);
 
+    int quantizeVectorWithIntervals(float[] vector, int[] quantize, float lowInterval, float upperInterval, byte bit);
+
 }

+ 28 - 0
libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java

@@ -791,4 +791,32 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
 
         return sum;
     }
+
+    @Override
+    public int quantizeVectorWithIntervals(float[] vector, int[] destination, float lowInterval, float upperInterval, byte bits) {
+        float nSteps = ((1 << bits) - 1);
+        float step = (upperInterval - lowInterval) / nSteps;
+        int sumQuery = 0;
+        int i = 0;
+        if (vector.length > 2 * FLOAT_SPECIES.length()) {
+            int limit = FLOAT_SPECIES.loopBound(vector.length);
+            FloatVector lowVec = FloatVector.broadcast(FLOAT_SPECIES, lowInterval);
+            FloatVector upperVec = FloatVector.broadcast(FLOAT_SPECIES, upperInterval);
+            FloatVector stepVec = FloatVector.broadcast(FLOAT_SPECIES, step);
+            for (; i < limit; i += FLOAT_SPECIES.length()) {
+                FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, vector, i);
+                FloatVector xi = v.max(lowVec).min(upperVec); // clamp
+                IntVector assignment = xi.sub(lowVec).div(stepVec).add(0.5f).convert(VectorOperators.F2I, 0).reinterpretAsInts(); // round
+                sumQuery += assignment.reduceLanes(ADD);
+                assignment.intoArray(destination, i);
+            }
+        }
+        for (; i < vector.length; i++) {
+            float xi = Math.min(Math.max(vector[i], lowInterval), upperInterval);
+            int assignment = Math.round((xi - lowInterval) / step);
+            sumQuery += assignment;
+            destination[i] = assignment;
+        }
+        return sumQuery;
+    }
 }

+ 23 - 0
libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java

@@ -286,6 +286,29 @@ public class ESVectorUtilTests extends BaseVectorizationTests {
         assertEquals(expected, result, deltaEps);
     }
 
+    public void testQuantizeVectorWithIntervals() {
+        int vectorSize = randomIntBetween(1, 2048);
+        float[] vector = new float[vectorSize];
+
+        byte bits = (byte) randomIntBetween(1, 8);
+        for (int i = 0; i < vectorSize; ++i) {
+            vector[i] = random().nextFloat();
+        }
+        float low = random().nextFloat();
+        float high = random().nextFloat();
+        if (low > high) {
+            float tmp = low;
+            low = high;
+            high = tmp;
+        }
+        int[] quantizeExpected = new int[vectorSize];
+        int[] quantizeResult = new int[vectorSize];
+        var expected = defaultedProvider.getVectorUtilSupport().quantizeVectorWithIntervals(vector, quantizeExpected, low, high, bits);
+        var result = defOrPanamaProvider.getVectorUtilSupport().quantizeVectorWithIntervals(vector, quantizeResult, low, high, bits);
+        assertArrayEquals(quantizeExpected, quantizeResult);
+        assertEquals(expected, result, 0f);
+    }
+
     void testIpByteBinImpl(ToLongBiFunction<byte[], byte[]> ipByteBinFunc) {
         int iterations = atLeast(50);
         for (int i = 0; i < iterations; i++) {

+ 29 - 0
server/src/main/java/org/elasticsearch/index/codec/vectors/BQSpaceUtils.java

@@ -57,4 +57,33 @@ public class BQSpaceUtils {
             quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
         }
     }
+
+    /**
+     * Same as {@link #transposeHalfByte(byte[], byte[])} but the input vector is provided as
+     * an array of integers.
+     *
+     * @param q the query vector, assumed to be half-byte quantized with values between 0 and 15
+     * @param quantQueryByte the byte array to store the transposed query vector
+     * */
+    public static void transposeHalfByte(int[] q, byte[] quantQueryByte) {
+        for (int i = 0; i < q.length;) {
+            assert q[i] >= 0 && q[i] <= 15;
+            int lowerByte = 0;
+            int lowerMiddleByte = 0;
+            int upperMiddleByte = 0;
+            int upperByte = 0;
+            for (int j = 7; j >= 0 && i < q.length; j--) {
+                lowerByte |= (q[i] & 1) << j;
+                lowerMiddleByte |= ((q[i] >> 1) & 1) << j;
+                upperMiddleByte |= ((q[i] >> 2) & 1) << j;
+                upperByte |= ((q[i] >> 3) & 1) << j;
+                i++;
+            }
+            int index = ((i + 7) / 8) - 1;
+            quantQueryByte[index] = (byte) lowerByte;
+            quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte;
+            quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte;
+            quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
+        }
+    }
 }

+ 1 - 1
server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java

@@ -40,7 +40,7 @@ public class BQVectorUtils {
         return Math.abs(l1norm - 1.0d) <= EPSILON;
     }
 
-    public static void packAsBinary(byte[] vector, byte[] packed) {
+    public static void packAsBinary(int[] vector, byte[] packed) {
         for (int i = 0; i < vector.length;) {
             byte result = 0;
             for (int j = 7; j >= 0 && i < vector.length; j--) {

+ 8 - 4
server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java

@@ -52,13 +52,17 @@ public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeap
         final FieldEntry fieldEntry = fields.get(fieldInfo.number);
         final float globalCentroidDp = fieldEntry.globalCentroidDp();
         final OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
-        final byte[] quantized = new byte[targetQuery.length];
+        final int[] scratch = new int[targetQuery.length];
         final OptimizedScalarQuantizer.QuantizationResult queryParams = scalarQuantizer.scalarQuantize(
             ArrayUtil.copyArray(targetQuery),
-            quantized,
+            scratch,
             (byte) 4,
             fieldEntry.globalCentroid()
         );
+        final byte[] quantized = new byte[targetQuery.length];
+        for (int i = 0; i < quantized.length; i++) {
+            quantized[i] = (byte) scratch[i];
+        }
         final ES91Int4VectorsScorer scorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, fieldInfo.getVectorDimension());
         return new CentroidQueryScorer() {
             int currentCentroid = -1;
@@ -182,7 +186,7 @@ public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeap
         DocIdsWriter docIdsWriter = new DocIdsWriter();
 
         final float[] scratch;
-        final byte[] quantizationScratch;
+        final int[] quantizationScratch;
         final byte[] quantizedQueryScratch;
         final OptimizedScalarQuantizer quantizer;
         final float[] correctiveValues = new float[3];
@@ -202,7 +206,7 @@ public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeap
             this.needsScoring = needsScoring;
 
             scratch = new float[target.length];
-            quantizationScratch = new byte[target.length];
+            quantizationScratch = new int[target.length];
             final int discretizedDimensions = discretize(fieldInfo.getVectorDimension(), 64);
             quantizedQueryScratch = new byte[QUERY_BITS * discretizedDimensions / 8];
             quantizedByteLength = discretizedDimensions / 8 + (Float.BYTES * 3) + Short.BYTES;

+ 6 - 2
server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java

@@ -122,8 +122,9 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
     static void writeCentroids(float[][] centroids, FieldInfo fieldInfo, float[] globalCentroid, IndexOutput centroidOutput)
         throws IOException {
         final OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
-        byte[] quantizedScratch = new byte[fieldInfo.getVectorDimension()];
+        int[] quantizedScratch = new int[fieldInfo.getVectorDimension()];
         float[] centroidScratch = new float[fieldInfo.getVectorDimension()];
+        final byte[] quantized = new byte[fieldInfo.getVectorDimension()];
         // TODO do we want to store these distances as well for future use?
         // TODO: sort centroids by global centroid (was doing so previously here)
         // TODO: sorting tanks recall possibly because centroids ordinals no longer are aligned
@@ -135,7 +136,10 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
                 (byte) 4,
                 globalCentroid
             );
-            writeQuantizedValue(centroidOutput, quantizedScratch, result);
+            for (int i = 0; i < quantizedScratch.length; i++) {
+                quantized[i] = (byte) quantizedScratch[i];
+            }
+            writeQuantizedValue(centroidOutput, quantized, result);
         }
         final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
         for (float[] centroid : centroids) {

+ 2 - 2
server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java

@@ -66,13 +66,13 @@ public abstract class DiskBBQBulkWriter {
 
     public static class OneBitDiskBBQBulkWriter extends DiskBBQBulkWriter {
         private final byte[] binarized;
-        private final byte[] initQuantized;
+        private final int[] initQuantized;
         private final OptimizedScalarQuantizer.QuantizationResult[] corrections;
 
         public OneBitDiskBBQBulkWriter(int bulkSize, OptimizedScalarQuantizer quantizer, FloatVectorValues fvv, IndexOutput out) {
             super(bulkSize, quantizer, fvv, out);
             this.binarized = new byte[discretize(fvv.dimension(), 64) / 8];
-            this.initQuantized = new byte[fvv.dimension()];
+            this.initQuantized = new int[fvv.dimension()];
             this.corrections = new OptimizedScalarQuantizer.QuantizationResult[bulkSize];
         }
 

+ 40 - 13
server/src/main/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizer.java

@@ -57,7 +57,7 @@ public class OptimizedScalarQuantizer {
 
     public record QuantizationResult(float lowerInterval, float upperInterval, float additionalCorrection, int quantizedComponentSum) {}
 
-    public QuantizationResult[] multiScalarQuantize(float[] vector, byte[][] destinations, byte[] bits, float[] centroid) {
+    public QuantizationResult[] multiScalarQuantize(float[] vector, int[][] destinations, byte[] bits, float[] centroid) {
         assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector);
         assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid);
         assert bits.length == destinations.length;
@@ -79,18 +79,14 @@ public class OptimizedScalarQuantizer {
             // Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds
             initInterval(bits[i], vecStd, vecMean, min, max, intervalScratch);
             optimizeIntervals(intervalScratch, vector, norm2, points);
-            float nSteps = ((1 << bits[i]) - 1);
-            float a = intervalScratch[0];
-            float b = intervalScratch[1];
-            float step = (b - a) / nSteps;
-            int sumQuery = 0;
             // Now we have the optimized intervals, quantize the vector
-            for (int h = 0; h < vector.length; h++) {
-                float xi = (float) clamp(vector[h], a, b);
-                int assignment = Math.round((xi - a) / step);
-                sumQuery += assignment;
-                destinations[i][h] = (byte) assignment;
-            }
+            int sumQuery = ESVectorUtil.quantizeVectorWithIntervals(
+                vector,
+                destinations[i],
+                intervalScratch[0],
+                intervalScratch[1],
+                bits[i]
+            );
             results[i] = new QuantizationResult(
                 intervalScratch[0],
                 intervalScratch[1],
@@ -101,7 +97,8 @@ public class OptimizedScalarQuantizer {
         return results;
     }
 
-    public QuantizationResult scalarQuantize(float[] vector, byte[] destination, byte bits, float[] centroid) {
+    // This method is only used for benchmarking and testing purposes; it is not used in production
+    public QuantizationResult legacyScalarQuantize(float[] vector, byte[] destination, byte bits, float[] centroid) {
         assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector);
         assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid);
         assert vector.length <= destination.length;
@@ -141,6 +138,36 @@ public class OptimizedScalarQuantizer {
         );
     }
 
+    public QuantizationResult scalarQuantize(float[] vector, int[] destination, byte bits, float[] centroid) {
+        assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector);
+        assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid);
+        assert vector.length <= destination.length;
+        assert bits > 0 && bits <= 8;
+        int points = 1 << bits;
+        if (similarityFunction == EUCLIDEAN) {
+            ESVectorUtil.centerAndCalculateOSQStatsEuclidean(vector, centroid, vector, statsScratch);
+        } else {
+            ESVectorUtil.centerAndCalculateOSQStatsDp(vector, centroid, vector, statsScratch);
+        }
+        float vecMean = statsScratch[0];
+        float vecVar = statsScratch[1];
+        float norm2 = statsScratch[2];
+        float min = statsScratch[3];
+        float max = statsScratch[4];
+        float vecStd = (float) Math.sqrt(vecVar);
+        // Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds
+        initInterval(bits, vecStd, vecMean, min, max, intervalScratch);
+        optimizeIntervals(intervalScratch, vector, norm2, points);
+        // Now we have the optimized intervals, quantize the vector
+        int sumQuery = ESVectorUtil.quantizeVectorWithIntervals(vector, destination, intervalScratch[0], intervalScratch[1], bits);
+        return new QuantizationResult(
+            intervalScratch[0],
+            intervalScratch[1],
+            similarityFunction == EUCLIDEAN ? norm2 : statsScratch[5],
+            sumQuery
+        );
+    }
+
     /**
      * Optimize the quantization interval for the given vector. This is done via a coordinate descent trying to minimize the quantization
      * loss. Note, the loss is not always guaranteed to decrease, so we have a maximum number of iterations and will exit early if the

+ 1 - 1
server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java

@@ -77,7 +77,7 @@ public class ES818BinaryFlatVectorsScorer implements FlatVectorsScorer {
                 VectorUtil.l2normalize(copy);
             }
             target = copy;
-            byte[] initial = new byte[target.length];
+            int[] initial = new int[target.length];
             byte[] quantized = new byte[BQSpaceUtils.B_QUERY * binarizedVectors.discretizedDimensions() / 8];
             OptimizedScalarQuantizer.QuantizationResult queryCorrections = quantizer.scalarQuantize(target, initial, (byte) 4, centroid);
             BQSpaceUtils.transposeHalfByte(initial, quantized);

+ 5 - 5
server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java

@@ -198,7 +198,7 @@ public class ES818BinaryQuantizedVectorsWriter extends FlatVectorsWriter {
     private void writeBinarizedVectors(FieldWriter fieldData, float[] clusterCenter, OptimizedScalarQuantizer scalarQuantizer)
         throws IOException {
         int discreteDims = BQVectorUtils.discretize(fieldData.fieldInfo.getVectorDimension(), 64);
-        byte[] quantizationScratch = new byte[discreteDims];
+        int[] quantizationScratch = new int[discreteDims];
         byte[] vector = new byte[discreteDims / 8];
         for (int i = 0; i < fieldData.getVectors().size(); i++) {
             float[] v = fieldData.getVectors().get(i);
@@ -246,7 +246,7 @@ public class ES818BinaryQuantizedVectorsWriter extends FlatVectorsWriter {
         OptimizedScalarQuantizer scalarQuantizer
     ) throws IOException {
         int discreteDims = BQVectorUtils.discretize(fieldData.fieldInfo.getVectorDimension(), 64);
-        byte[] quantizationScratch = new byte[discreteDims];
+        int[] quantizationScratch = new int[discreteDims];
         byte[] vector = new byte[discreteDims / 8];
         for (int ordinal : ordMap) {
             float[] v = fieldData.getVectors().get(ordinal);
@@ -364,7 +364,7 @@ public class ES818BinaryQuantizedVectorsWriter extends FlatVectorsWriter {
     ) throws IOException {
         int discretizedDimension = BQVectorUtils.discretize(floatVectorValues.dimension(), 64);
         DocsWithFieldSet docsWithField = new DocsWithFieldSet();
-        byte[][] quantizationScratch = new byte[2][floatVectorValues.dimension()];
+        int[][] quantizationScratch = new int[2][floatVectorValues.dimension()];
         byte[] toIndex = new byte[discretizedDimension / 8];
         byte[] toQuery = new byte[(discretizedDimension / 8) * BQSpaceUtils.B_QUERY];
         KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator();
@@ -801,7 +801,7 @@ public class ES818BinaryQuantizedVectorsWriter extends FlatVectorsWriter {
     static class BinarizedFloatVectorValues extends BinarizedByteVectorValues {
         private OptimizedScalarQuantizer.QuantizationResult corrections;
         private final byte[] binarized;
-        private final byte[] initQuantized;
+        private final int[] initQuantized;
         private final float[] centroid;
         private final FloatVectorValues values;
         private final OptimizedScalarQuantizer quantizer;
@@ -812,7 +812,7 @@ public class ES818BinaryQuantizedVectorsWriter extends FlatVectorsWriter {
             this.values = delegate;
             this.quantizer = quantizer;
             this.binarized = new byte[BQVectorUtils.discretize(delegate.dimension(), 64) / 8];
-            this.initQuantized = new byte[delegate.dimension()];
+            this.initQuantized = new int[delegate.dimension()];
             this.centroid = centroid;
         }
 

+ 4 - 4
server/src/test/java/org/elasticsearch/index/codec/vectors/BQVectorUtilsTests.java

@@ -40,25 +40,25 @@ public class BQVectorUtilsTests extends LuceneTestCase {
 
     public void testPackAsBinary() {
         // 5 bits
-        byte[] toPack = new byte[] { 1, 1, 0, 0, 1 };
+        int[] toPack = new int[] { 1, 1, 0, 0, 1 };
         byte[] packed = new byte[1];
         BQVectorUtils.packAsBinary(toPack, packed);
         assertArrayEquals(new byte[] { (byte) 0b11001000 }, packed);
 
         // 8 bits
-        toPack = new byte[] { 1, 1, 0, 0, 1, 0, 1, 0 };
+        toPack = new int[] { 1, 1, 0, 0, 1, 0, 1, 0 };
         packed = new byte[1];
         BQVectorUtils.packAsBinary(toPack, packed);
         assertArrayEquals(new byte[] { (byte) 0b11001010 }, packed);
 
         // 10 bits
-        toPack = new byte[] { 1, 1, 0, 0, 1, 0, 1, 0, 1, 1 };
+        toPack = new int[] { 1, 1, 0, 0, 1, 0, 1, 0, 1, 1 };
         packed = new byte[2];
         BQVectorUtils.packAsBinary(toPack, packed);
         assertArrayEquals(new byte[] { (byte) 0b11001010, (byte) 0b11000000 }, packed);
 
         // 16 bits
-        toPack = new byte[] { 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0 };
+        toPack = new int[] { 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0 };
         packed = new byte[2];
         BQVectorUtils.packAsBinary(toPack, packed);
         assertArrayEquals(new byte[] { (byte) 0b11001010, (byte) 0b11100110 }, packed);

+ 28 - 13
server/src/test/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizerTests.java

@@ -19,7 +19,7 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
 
     static final byte[] ALL_BITS = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 };
 
-    static float[] deQuantize(byte[] quantized, byte bits, float[] interval, float[] centroid) {
+    static float[] deQuantize(int[] quantized, byte bits, float[] interval, float[] centroid) {
         float[] dequantized = new float[quantized.length];
         float a = interval[0];
         float b = interval[1];
@@ -52,10 +52,12 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
         float[] scratch = new float[dims];
         for (byte bit : ALL_BITS) {
             float eps = (1f / (float) (1 << (bit)));
-            byte[] destination = new byte[dims];
+            byte[] legacyDestination = new byte[dims];
+            int[] destination = new int[dims];
             for (int i = 0; i < numVectors; ++i) {
                 System.arraycopy(vectors[i], 0, scratch, 0, dims);
                 OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(scratch, destination, bit, centroid);
+
                 assertValidResults(result);
                 assertValidQuantizedRange(destination, bit);
 
@@ -71,6 +73,19 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
                 }
                 mae /= dims;
                 assertTrue("bits: " + bit + " mae: " + mae + " > eps: " + eps, mae <= eps);
+
+                // check we get the same result from the legacy byte version
+                System.arraycopy(vectors[i], 0, scratch, 0, dims);
+                OptimizedScalarQuantizer.QuantizationResult intResults = osq.legacyScalarQuantize(
+                    scratch,
+                    legacyDestination,
+                    bit,
+                    centroid
+                );
+                assertEquals(result, intResults);
+                for (int h = 0; h < dims; ++h) {
+                    assertEquals((byte) destination[h], legacyDestination[h]);
+                }
             }
         }
     }
@@ -84,18 +99,18 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
             float[] vector = new float[4096];
             float[] centroid = new float[4096];
             OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(vectorSimilarityFunction);
-            byte[][] destinations = new byte[MINIMUM_MSE_GRID.length][4096];
+            int[][] destinations = new int[MINIMUM_MSE_GRID.length][4096];
             OptimizedScalarQuantizer.QuantizationResult[] results = osq.multiScalarQuantize(vector, destinations, ALL_BITS, centroid);
             assertEquals(MINIMUM_MSE_GRID.length, results.length);
             assertValidResults(results);
-            for (byte[] destination : destinations) {
-                assertArrayEquals(new byte[4096], destination);
+            for (int[] destination : destinations) {
+                assertArrayEquals(new int[4096], destination);
             }
-            byte[] destination = new byte[4096];
+            int[] destination = new int[4096];
             for (byte bit : ALL_BITS) {
                 OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(vector, destination, bit, centroid);
                 assertValidResults(result);
-                assertArrayEquals(new byte[4096], destination);
+                assertArrayEquals(new int[4096], destination);
             }
         }
 
@@ -108,7 +123,7 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
                 VectorUtil.l2normalize(centroid);
             }
             OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(vectorSimilarityFunction);
-            byte[][] destinations = new byte[MINIMUM_MSE_GRID.length][1];
+            int[][] destinations = new int[MINIMUM_MSE_GRID.length][1];
             OptimizedScalarQuantizer.QuantizationResult[] results = osq.multiScalarQuantize(vector, destinations, ALL_BITS, centroid);
             assertEquals(MINIMUM_MSE_GRID.length, results.length);
             assertValidResults(results);
@@ -122,7 +137,7 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
                     VectorUtil.l2normalize(vector);
                     VectorUtil.l2normalize(centroid);
                 }
-                byte[] destination = new byte[1];
+                int[] destination = new int[1];
                 OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(vector, destination, bit, centroid);
                 assertValidResults(result);
                 assertValidQuantizedRange(destination, bit);
@@ -150,7 +165,7 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
                 VectorUtil.l2normalize(centroid);
             }
             OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(vectorSimilarityFunction);
-            byte[][] destinations = new byte[MINIMUM_MSE_GRID.length][dims];
+            int[][] destinations = new int[MINIMUM_MSE_GRID.length][dims];
             OptimizedScalarQuantizer.QuantizationResult[] results = osq.multiScalarQuantize(copy, destinations, ALL_BITS, centroid);
             assertEquals(MINIMUM_MSE_GRID.length, results.length);
             assertValidResults(results);
@@ -158,7 +173,7 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
                 assertValidQuantizedRange(destinations[i], ALL_BITS[i]);
             }
             for (byte bit : ALL_BITS) {
-                byte[] destination = new byte[dims];
+                int[] destination = new int[dims];
                 System.arraycopy(vector, 0, copy, 0, dims);
                 if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) {
                     VectorUtil.l2normalize(copy);
@@ -171,8 +186,8 @@ public class OptimizedScalarQuantizerTests extends ESTestCase {
         }
     }
 
-    static void assertValidQuantizedRange(byte[] quantized, byte bits) {
-        for (byte b : quantized) {
+    static void assertValidQuantizedRange(int[] quantized, byte bits) {
+        for (int b : quantized) {
             if (bits < 8) {
                 assertTrue(b >= 0);
             }

+ 1 - 1
server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormatTests.java

@@ -243,7 +243,7 @@ public class ES818BinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormat
                     assertEquals(centroid.length, dims);
 
                     OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(similarityFunction);
-                    byte[] quantizedVector = new byte[dims];
+                    int[] quantizedVector = new int[dims];
                     byte[] expectedVector = new byte[BQVectorUtils.discretize(dims, 64) / 8];
                     if (similarityFunction == VectorSimilarityFunction.COSINE) {
                         vectorValues = new ES818BinaryQuantizedVectorsWriter.NormalizedFloatVectorValues(vectorValues);