Browse Source

New bulk scorer for binary quantized vectors via optimized scalar quantization (#127189)

* New bulk scorer for binary quantized vectors via optimized scalar quantization

* fixing headers

* fixing tests
Benjamin Trent 5 months ago
parent
commit
74faf47121

+ 204 - 0
benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java

@@ -0,0 +1,204 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+package org.elasticsearch.benchmark.vector;
+
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.store.IOContext;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.IndexOutput;
+import org.apache.lucene.store.MMapDirectory;
+import org.apache.lucene.util.VectorUtil;
+import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
+import org.elasticsearch.common.logging.LogConfigurator;
+import org.elasticsearch.simdvec.internal.vectorization.ES91OSQVectorsScorer;
+import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Param;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.Warmup;
+import org.openjdk.jmh.infra.Blackhole;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.util.Random;
+import java.util.concurrent.TimeUnit;
+
+@BenchmarkMode(Mode.Throughput)
+@OutputTimeUnit(TimeUnit.MILLISECONDS)
+@State(Scope.Benchmark)
+// first iteration is complete garbage, so make sure we really warmup
+@Warmup(iterations = 4, time = 1)
+// real iterations. not useful to spend tons of time here, better to fork more
+@Measurement(iterations = 5, time = 1)
+// engage some noise reduction
+@Fork(value = 1)
+public class OSQScorerBenchmark {
+
+    static {
+        LogConfigurator.configureESLogging(); // native access requires logging to be initialized
+    }
+
+    @Param({ "1024" })
+    int dims;
+
+    int length;
+
+    int numVectors = ES91OSQVectorsScorer.BULK_SIZE * 10;
+    int numQueries = 10;
+
+    byte[][] binaryVectors;
+    byte[][] binaryQueries;
+    OptimizedScalarQuantizer.QuantizationResult result;
+    float centroidDp;
+
+    byte[] scratch;
+    ES91OSQVectorsScorer scorer;
+
+    IndexInput in;
+
+    float[] scratchScores;
+    float[] corrections;
+
+    @Setup
+    public void setup() throws IOException {
+        Random random = new Random(123);
+
+        this.length = OptimizedScalarQuantizer.discretize(dims, 64) / 8;
+
+        binaryVectors = new byte[numVectors][length];
+        for (byte[] binaryVector : binaryVectors) {
+            random.nextBytes(binaryVector);
+        }
+
+        Directory dir = new MMapDirectory(Files.createTempDirectory("vectorData"));
+        IndexOutput out = dir.createOutput("vectors", IOContext.DEFAULT);
+        byte[] correctionBytes = new byte[14 * ES91OSQVectorsScorer.BULK_SIZE];
+        for (int i = 0; i < numVectors; i += ES91OSQVectorsScorer.BULK_SIZE) {
+            for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
+                out.writeBytes(binaryVectors[i + j], 0, binaryVectors[i + j].length);
+            }
+            random.nextBytes(correctionBytes);
+            out.writeBytes(correctionBytes, 0, correctionBytes.length);
+        }
+        out.close();
+        in = dir.openInput("vectors", IOContext.DEFAULT);
+
+        binaryQueries = new byte[numVectors][4 * length];
+        for (byte[] binaryVector : binaryVectors) {
+            random.nextBytes(binaryVector);
+        }
+        result = new OptimizedScalarQuantizer.QuantizationResult(
+            random.nextFloat(),
+            random.nextFloat(),
+            random.nextFloat(),
+            Short.toUnsignedInt((short) random.nextInt())
+        );
+        centroidDp = random.nextFloat();
+
+        scratch = new byte[length];
+        scorer = ESVectorizationProvider.getInstance().newES91OSQVectorsScorer(in, dims);
+        scratchScores = new float[16];
+        corrections = new float[3];
+    }
+
+    @Benchmark
+    @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
+    public void scoreFromArray(Blackhole bh) throws IOException {
+        for (int j = 0; j < numQueries; j++) {
+            in.seek(0);
+            for (int i = 0; i < numVectors; i++) {
+                in.readBytes(scratch, 0, length);
+                float qDist = VectorUtil.int4BitDotProduct(binaryQueries[j], scratch);
+                in.readFloats(corrections, 0, corrections.length);
+                int addition = Short.toUnsignedInt(in.readShort());
+                float score = scorer.score(
+                    result,
+                    VectorSimilarityFunction.EUCLIDEAN,
+                    centroidDp,
+                    corrections[0],
+                    corrections[1],
+                    addition,
+                    corrections[2],
+                    qDist
+                );
+                bh.consume(score);
+            }
+        }
+    }
+
+    @Benchmark
+    @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
+    public void scoreFromMemorySegmentOnlyVector(Blackhole bh) throws IOException {
+        for (int j = 0; j < numQueries; j++) {
+            in.seek(0);
+            for (int i = 0; i < numVectors; i++) {
+                float qDist = scorer.quantizeScore(binaryQueries[j]);
+                in.readFloats(corrections, 0, corrections.length);
+                int addition = Short.toUnsignedInt(in.readShort());
+                float score = scorer.score(
+                    result,
+                    VectorSimilarityFunction.EUCLIDEAN,
+                    centroidDp,
+                    corrections[0],
+                    corrections[1],
+                    addition,
+                    corrections[2],
+                    qDist
+                );
+                bh.consume(score);
+            }
+        }
+    }
+
+    @Benchmark
+    @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
+    public void scoreFromMemorySegmentOnlyVectorBulk(Blackhole bh) throws IOException {
+        for (int j = 0; j < numQueries; j++) {
+            in.seek(0);
+            for (int i = 0; i < numVectors; i += 16) {
+                scorer.quantizeScoreBulk(binaryQueries[j], ES91OSQVectorsScorer.BULK_SIZE, scratchScores);
+                for (int k = 0; k < ES91OSQVectorsScorer.BULK_SIZE; k++) {
+                    in.readFloats(corrections, 0, corrections.length);
+                    int addition = Short.toUnsignedInt(in.readShort());
+                    float score = scorer.score(
+                        result,
+                        VectorSimilarityFunction.EUCLIDEAN,
+                        centroidDp,
+                        corrections[0],
+                        corrections[1],
+                        addition,
+                        corrections[2],
+                        scratchScores[k]
+                    );
+                    bh.consume(score);
+                }
+            }
+        }
+    }
+
+    @Benchmark
+    @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
+    public void scoreFromMemorySegmentAllBulk(Blackhole bh) throws IOException {
+        for (int j = 0; j < numQueries; j++) {
+            in.seek(0);
+            for (int i = 0; i < numVectors; i += 16) {
+                scorer.scoreBulk(binaryQueries[j], result, VectorSimilarityFunction.EUCLIDEAN, centroidDp, scratchScores);
+                bh.consume(scratchScores);
+            }
+        }
+    }
+}

+ 9 - 0
libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java

@@ -9,6 +9,10 @@
 
 package org.elasticsearch.simdvec.internal.vectorization;
 
+import org.apache.lucene.store.IndexInput;
+
+import java.io.IOException;
+
 final class DefaultESVectorizationProvider extends ESVectorizationProvider {
     private final ESVectorUtilSupport vectorUtilSupport;
 
@@ -20,4 +24,9 @@ final class DefaultESVectorizationProvider extends ESVectorizationProvider {
     public ESVectorUtilSupport getVectorUtilSupport() {
         return vectorUtilSupport;
     }
+
+    @Override
+    public ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException {
+        return new ES91OSQVectorsScorer(input, dimension);
+    }
 }

+ 168 - 0
libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorsScorer.java

@@ -0,0 +1,168 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+package org.elasticsearch.simdvec.internal.vectorization;
+
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.util.BitUtil;
+import org.apache.lucene.util.VectorUtil;
+import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
+
+import java.io.IOException;
+
+import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
+import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
+
+/** Scorer for quantized vectors stored as an {@link IndexInput}. */
+public class ES91OSQVectorsScorer {
+
+    public static final int BULK_SIZE = 16;
+
+    protected static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1);
+
+    /** The wrapper {@link IndexInput}. */
+    protected final IndexInput in;
+
+    protected final int length;
+    protected final int dimensions;
+
+    protected final float[] lowerIntervals = new float[BULK_SIZE];
+    protected final float[] upperIntervals = new float[BULK_SIZE];
+    protected final int[] targetComponentSums = new int[BULK_SIZE];
+    protected final float[] additionalCorrections = new float[BULK_SIZE];
+
+    /** Sole constructor, called by sub-classes. */
+    public ES91OSQVectorsScorer(IndexInput in, int dimensions) {
+        this.in = in;
+        this.dimensions = dimensions;
+        this.length = OptimizedScalarQuantizer.discretize(dimensions, 64) / 8;
+    }
+
+    /**
+     * compute the quantize distance between the provided quantized query and the quantized vector
+     * that is read from the wrapped {@link IndexInput}.
+     */
+    public long quantizeScore(byte[] q) throws IOException {
+        assert q.length == length * 4;
+        final int size = length;
+        long subRet0 = 0;
+        long subRet1 = 0;
+        long subRet2 = 0;
+        long subRet3 = 0;
+        int r = 0;
+        for (final int upperBound = size & -Long.BYTES; r < upperBound; r += Long.BYTES) {
+            final long value = in.readLong();
+            subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r) & value);
+            subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r + size) & value);
+            subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r + 2 * size) & value);
+            subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r + 3 * size) & value);
+        }
+        for (final int upperBound = size & -Integer.BYTES; r < upperBound; r += Integer.BYTES) {
+            final int value = in.readInt();
+            subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r) & value);
+            subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r + size) & value);
+            subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r + 2 * size) & value);
+            subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r + 3 * size) & value);
+        }
+        for (; r < size; r++) {
+            final byte value = in.readByte();
+            subRet0 += Integer.bitCount((q[r] & value) & 0xFF);
+            subRet1 += Integer.bitCount((q[r + size] & value) & 0xFF);
+            subRet2 += Integer.bitCount((q[r + 2 * size] & value) & 0xFF);
+            subRet3 += Integer.bitCount((q[r + 3 * size] & value) & 0xFF);
+        }
+        return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
+    }
+
+    /**
+     * compute the quantize distance between the provided quantized query and the quantized vectors
+     * that are read from the wrapped {@link IndexInput}. The number of quantized vectors to read is
+     * determined by {code count} and the results are stored in the provided {@code scores} array.
+     */
+    public void quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOException {
+        for (int i = 0; i < count; i++) {
+            scores[i] = quantizeScore(q);
+        }
+    }
+
+    /**
+     * Computes the score by applying the necessary corrections to the provided quantized distance.
+     */
+    public float score(
+        OptimizedScalarQuantizer.QuantizationResult queryCorrections,
+        VectorSimilarityFunction similarityFunction,
+        float centroidDp,
+        float lowerInterval,
+        float upperInterval,
+        int targetComponentSum,
+        float additionalCorrection,
+        float qcDist
+    ) {
+        float ax = lowerInterval;
+        // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
+        float lx = upperInterval - ax;
+        float ay = queryCorrections.lowerInterval();
+        float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
+        float y1 = queryCorrections.quantizedComponentSum();
+        float score = ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
+        // For euclidean, we need to invert the score and apply the additional correction, which is
+        // assumed to be the squared l2norm of the centroid centered vectors.
+        if (similarityFunction == EUCLIDEAN) {
+            score = queryCorrections.additionalCorrection() + additionalCorrection - 2 * score;
+            return Math.max(1 / (1f + score), 0);
+        } else {
+            // For cosine and max inner product, we need to apply the additional correction, which is
+            // assumed to be the non-centered dot-product between the vector and the centroid
+            score += queryCorrections.additionalCorrection() + additionalCorrection - centroidDp;
+            if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
+                return VectorUtil.scaleMaxInnerProductScore(score);
+            }
+            return Math.max((1f + score) / 2f, 0);
+        }
+    }
+
+    /**
+     * compute the distance between the provided quantized query and the quantized vectors that are
+     * read from the wrapped {@link IndexInput}.
+     *
+     * <p>The number of vectors to score is defined by {@link #BULK_SIZE}. The expected format of the
+     * input is as follows: First the quantized vectors are read from the input,then all the lower
+     * intervals as floats, then all the upper intervals as floats, then all the target component sums
+     * as shorts, and finally all the additional corrections as floats.
+     *
+     * <p>The results are stored in the provided scores array.
+     */
+    public void scoreBulk(
+        byte[] q,
+        OptimizedScalarQuantizer.QuantizationResult queryCorrections,
+        VectorSimilarityFunction similarityFunction,
+        float centroidDp,
+        float[] scores
+    ) throws IOException {
+        quantizeScoreBulk(q, BULK_SIZE, scores);
+        in.readFloats(lowerIntervals, 0, BULK_SIZE);
+        in.readFloats(upperIntervals, 0, BULK_SIZE);
+        for (int i = 0; i < BULK_SIZE; i++) {
+            targetComponentSums[i] = Short.toUnsignedInt(in.readShort());
+        }
+        in.readFloats(additionalCorrections, 0, BULK_SIZE);
+        for (int i = 0; i < BULK_SIZE; i++) {
+            scores[i] = score(
+                queryCorrections,
+                similarityFunction,
+                centroidDp,
+                lowerIntervals[i],
+                upperIntervals[i],
+                targetComponentSums[i],
+                additionalCorrections[i],
+                scores[i]
+            );
+        }
+    }
+}

+ 6 - 0
libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java

@@ -9,6 +9,9 @@
 
 package org.elasticsearch.simdvec.internal.vectorization;
 
+import org.apache.lucene.store.IndexInput;
+
+import java.io.IOException;
 import java.util.Objects;
 
 public abstract class ESVectorizationProvider {
@@ -24,6 +27,9 @@ public abstract class ESVectorizationProvider {
 
     public abstract ESVectorUtilSupport getVectorUtilSupport();
 
+    /** Create a new {@link ES91OSQVectorsScorer} for the given {@link IndexInput}. */
+    public abstract ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException;
+
     // visible for tests
     static ESVectorizationProvider lookup(boolean testMode) {
         return new DefaultESVectorizationProvider();

+ 5 - 0
libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java

@@ -9,10 +9,12 @@
 
 package org.elasticsearch.simdvec.internal.vectorization;
 
+import org.apache.lucene.store.IndexInput;
 import org.apache.lucene.util.Constants;
 import org.elasticsearch.logging.LogManager;
 import org.elasticsearch.logging.Logger;
 
+import java.io.IOException;
 import java.util.Locale;
 import java.util.Objects;
 import java.util.Optional;
@@ -32,6 +34,9 @@ public abstract class ESVectorizationProvider {
 
     public abstract ESVectorUtilSupport getVectorUtilSupport();
 
+    /** Create a new {@link ES91OSQVectorsScorer} for the given {@link IndexInput}. */
+    public abstract ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException;
+
     // visible for tests
     static ESVectorizationProvider lookup(boolean testMode) {
         final int runtimeVersion = Runtime.version().feature();

+ 452 - 0
libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java

@@ -0,0 +1,452 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+package org.elasticsearch.simdvec.internal.vectorization;
+
+import jdk.incubator.vector.ByteVector;
+import jdk.incubator.vector.FloatVector;
+import jdk.incubator.vector.IntVector;
+import jdk.incubator.vector.LongVector;
+import jdk.incubator.vector.ShortVector;
+import jdk.incubator.vector.VectorOperators;
+import jdk.incubator.vector.VectorSpecies;
+
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.util.VectorUtil;
+import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
+
+import java.io.IOException;
+import java.lang.foreign.MemorySegment;
+import java.nio.ByteOrder;
+
+import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
+import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
+
+/** Panamized scorer for quantized vectors stored as an {@link IndexInput}. */
+public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScorer {
+
+    private static final VectorSpecies<Integer> INT_SPECIES_128 = IntVector.SPECIES_128;
+
+    private static final VectorSpecies<Long> LONG_SPECIES_128 = LongVector.SPECIES_128;
+    private static final VectorSpecies<Long> LONG_SPECIES_256 = LongVector.SPECIES_256;
+
+    private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
+    private static final VectorSpecies<Byte> BYTE_SPECIES_256 = ByteVector.SPECIES_256;
+
+    private static final VectorSpecies<Short> SHORT_SPECIES_128 = ShortVector.SPECIES_128;
+    private static final VectorSpecies<Short> SHORT_SPECIES_256 = ShortVector.SPECIES_256;
+
+    private static final VectorSpecies<Float> FLOAT_SPECIES_128 = FloatVector.SPECIES_128;
+    private static final VectorSpecies<Float> FLOAT_SPECIES_256 = FloatVector.SPECIES_256;
+
+    private final MemorySegment memorySegment;
+
+    public MemorySegmentES91OSQVectorsScorer(IndexInput in, int dimensions, MemorySegment memorySegment) {
+        super(in, dimensions);
+        this.memorySegment = memorySegment;
+    }
+
+    @Override
+    public long quantizeScore(byte[] q) throws IOException {
+        assert q.length == length * 4;
+        // 128 / 8 == 16
+        if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) {
+            if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) {
+                return quantizeScore256(q);
+            } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) {
+                return quantizeScore128(q);
+            }
+        }
+        return super.quantizeScore(q);
+    }
+
+    private long quantizeScore256(byte[] q) throws IOException {
+        long subRet0 = 0;
+        long subRet1 = 0;
+        long subRet2 = 0;
+        long subRet3 = 0;
+        int i = 0;
+        long offset = in.getFilePointer();
+        if (length >= ByteVector.SPECIES_256.vectorByteSize() * 2) {
+            int limit = ByteVector.SPECIES_256.loopBound(length);
+            var sum0 = LongVector.zero(LONG_SPECIES_256);
+            var sum1 = LongVector.zero(LONG_SPECIES_256);
+            var sum2 = LongVector.zero(LONG_SPECIES_256);
+            var sum3 = LongVector.zero(LONG_SPECIES_256);
+            for (; i < limit; i += ByteVector.SPECIES_256.length(), offset += LONG_SPECIES_256.vectorByteSize()) {
+                var vq0 = ByteVector.fromArray(BYTE_SPECIES_256, q, i).reinterpretAsLongs();
+                var vq1 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length).reinterpretAsLongs();
+                var vq2 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 2).reinterpretAsLongs();
+                var vq3 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 3).reinterpretAsLongs();
+                var vd = LongVector.fromMemorySegment(LONG_SPECIES_256, memorySegment, offset, ByteOrder.LITTLE_ENDIAN);
+                sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT));
+                sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT));
+                sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT));
+                sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT));
+            }
+            subRet0 += sum0.reduceLanes(VectorOperators.ADD);
+            subRet1 += sum1.reduceLanes(VectorOperators.ADD);
+            subRet2 += sum2.reduceLanes(VectorOperators.ADD);
+            subRet3 += sum3.reduceLanes(VectorOperators.ADD);
+        }
+
+        if (length - i >= ByteVector.SPECIES_128.vectorByteSize()) {
+            var sum0 = LongVector.zero(LONG_SPECIES_128);
+            var sum1 = LongVector.zero(LONG_SPECIES_128);
+            var sum2 = LongVector.zero(LONG_SPECIES_128);
+            var sum3 = LongVector.zero(LONG_SPECIES_128);
+            int limit = ByteVector.SPECIES_128.loopBound(length);
+            for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += LONG_SPECIES_128.vectorByteSize()) {
+                var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsLongs();
+                var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsLongs();
+                var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsLongs();
+                var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsLongs();
+                var vd = LongVector.fromMemorySegment(LONG_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN);
+                sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT));
+                sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT));
+                sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT));
+                sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT));
+            }
+            subRet0 += sum0.reduceLanes(VectorOperators.ADD);
+            subRet1 += sum1.reduceLanes(VectorOperators.ADD);
+            subRet2 += sum2.reduceLanes(VectorOperators.ADD);
+            subRet3 += sum3.reduceLanes(VectorOperators.ADD);
+        }
+        // tail as bytes
+        in.seek(offset);
+        for (; i < length; i++) {
+            int dValue = in.readByte() & 0xFF;
+            subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF);
+            subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF);
+            subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF);
+            subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF);
+        }
+        return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
+    }
+
+    private long quantizeScore128(byte[] q) throws IOException {
+        long subRet0 = 0;
+        long subRet1 = 0;
+        long subRet2 = 0;
+        long subRet3 = 0;
+        int i = 0;
+        long offset = in.getFilePointer();
+
+        var sum0 = IntVector.zero(INT_SPECIES_128);
+        var sum1 = IntVector.zero(INT_SPECIES_128);
+        var sum2 = IntVector.zero(INT_SPECIES_128);
+        var sum3 = IntVector.zero(INT_SPECIES_128);
+        int limit = ByteVector.SPECIES_128.loopBound(length);
+        for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += INT_SPECIES_128.vectorByteSize()) {
+            var vd = IntVector.fromMemorySegment(INT_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN);
+            var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsInts();
+            var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsInts();
+            var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsInts();
+            var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsInts();
+            sum0 = sum0.add(vd.and(vq0).lanewise(VectorOperators.BIT_COUNT));
+            sum1 = sum1.add(vd.and(vq1).lanewise(VectorOperators.BIT_COUNT));
+            sum2 = sum2.add(vd.and(vq2).lanewise(VectorOperators.BIT_COUNT));
+            sum3 = sum3.add(vd.and(vq3).lanewise(VectorOperators.BIT_COUNT));
+        }
+        subRet0 += sum0.reduceLanes(VectorOperators.ADD);
+        subRet1 += sum1.reduceLanes(VectorOperators.ADD);
+        subRet2 += sum2.reduceLanes(VectorOperators.ADD);
+        subRet3 += sum3.reduceLanes(VectorOperators.ADD);
+        // tail as bytes
+        in.seek(offset);
+        for (; i < length; i++) {
+            int dValue = in.readByte() & 0xFF;
+            subRet0 += Integer.bitCount((dValue & q[i]) & 0xFF);
+            subRet1 += Integer.bitCount((dValue & q[i + length]) & 0xFF);
+            subRet2 += Integer.bitCount((dValue & q[i + 2 * length]) & 0xFF);
+            subRet3 += Integer.bitCount((dValue & q[i + 3 * length]) & 0xFF);
+        }
+        return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
+    }
+
+    @Override
+    public void quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOException {
+        assert q.length == length * 4;
+        // 128 / 8 == 16
+        if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) {
+            if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) {
+                quantizeScore256Bulk(q, count, scores);
+                return;
+            } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) {
+                quantizeScore128Bulk(q, count, scores);
+                return;
+            }
+        }
+        super.quantizeScoreBulk(q, count, scores);
+    }
+
+    private void quantizeScore128Bulk(byte[] q, int count, float[] scores) throws IOException {
+        for (int iter = 0; iter < count; iter++) {
+            long subRet0 = 0;
+            long subRet1 = 0;
+            long subRet2 = 0;
+            long subRet3 = 0;
+            int i = 0;
+            long offset = in.getFilePointer();
+
+            var sum0 = IntVector.zero(INT_SPECIES_128);
+            var sum1 = IntVector.zero(INT_SPECIES_128);
+            var sum2 = IntVector.zero(INT_SPECIES_128);
+            var sum3 = IntVector.zero(INT_SPECIES_128);
+            int limit = ByteVector.SPECIES_128.loopBound(length);
+            for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += INT_SPECIES_128.vectorByteSize()) {
+                var vd = IntVector.fromMemorySegment(INT_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN);
+                var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsInts();
+                var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsInts();
+                var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsInts();
+                var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsInts();
+                sum0 = sum0.add(vd.and(vq0).lanewise(VectorOperators.BIT_COUNT));
+                sum1 = sum1.add(vd.and(vq1).lanewise(VectorOperators.BIT_COUNT));
+                sum2 = sum2.add(vd.and(vq2).lanewise(VectorOperators.BIT_COUNT));
+                sum3 = sum3.add(vd.and(vq3).lanewise(VectorOperators.BIT_COUNT));
+            }
+            subRet0 += sum0.reduceLanes(VectorOperators.ADD);
+            subRet1 += sum1.reduceLanes(VectorOperators.ADD);
+            subRet2 += sum2.reduceLanes(VectorOperators.ADD);
+            subRet3 += sum3.reduceLanes(VectorOperators.ADD);
+            // tail as bytes
+            in.seek(offset);
+            for (; i < length; i++) {
+                int dValue = in.readByte() & 0xFF;
+                subRet0 += Integer.bitCount((dValue & q[i]) & 0xFF);
+                subRet1 += Integer.bitCount((dValue & q[i + length]) & 0xFF);
+                subRet2 += Integer.bitCount((dValue & q[i + 2 * length]) & 0xFF);
+                subRet3 += Integer.bitCount((dValue & q[i + 3 * length]) & 0xFF);
+            }
+            scores[iter] = subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
+        }
+    }
+
+    private void quantizeScore256Bulk(byte[] q, int count, float[] scores) throws IOException {
+        for (int iter = 0; iter < count; iter++) {
+            long subRet0 = 0;
+            long subRet1 = 0;
+            long subRet2 = 0;
+            long subRet3 = 0;
+            int i = 0;
+            long offset = in.getFilePointer();
+            if (length >= ByteVector.SPECIES_256.vectorByteSize() * 2) {
+                int limit = ByteVector.SPECIES_256.loopBound(length);
+                var sum0 = LongVector.zero(LONG_SPECIES_256);
+                var sum1 = LongVector.zero(LONG_SPECIES_256);
+                var sum2 = LongVector.zero(LONG_SPECIES_256);
+                var sum3 = LongVector.zero(LONG_SPECIES_256);
+                for (; i < limit; i += ByteVector.SPECIES_256.length(), offset += LONG_SPECIES_256.vectorByteSize()) {
+                    var vq0 = ByteVector.fromArray(BYTE_SPECIES_256, q, i).reinterpretAsLongs();
+                    var vq1 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length).reinterpretAsLongs();
+                    var vq2 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 2).reinterpretAsLongs();
+                    var vq3 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 3).reinterpretAsLongs();
+                    var vd = LongVector.fromMemorySegment(LONG_SPECIES_256, memorySegment, offset, ByteOrder.LITTLE_ENDIAN);
+                    sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT));
+                    sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT));
+                    sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT));
+                    sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT));
+                }
+                subRet0 += sum0.reduceLanes(VectorOperators.ADD);
+                subRet1 += sum1.reduceLanes(VectorOperators.ADD);
+                subRet2 += sum2.reduceLanes(VectorOperators.ADD);
+                subRet3 += sum3.reduceLanes(VectorOperators.ADD);
+            }
+
+            if (length - i >= ByteVector.SPECIES_128.vectorByteSize()) {
+                var sum0 = LongVector.zero(LONG_SPECIES_128);
+                var sum1 = LongVector.zero(LONG_SPECIES_128);
+                var sum2 = LongVector.zero(LONG_SPECIES_128);
+                var sum3 = LongVector.zero(LONG_SPECIES_128);
+                int limit = ByteVector.SPECIES_128.loopBound(length);
+                for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += LONG_SPECIES_128.vectorByteSize()) {
+                    var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsLongs();
+                    var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsLongs();
+                    var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsLongs();
+                    var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsLongs();
+                    var vd = LongVector.fromMemorySegment(LONG_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN);
+                    sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT));
+                    sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT));
+                    sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT));
+                    sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT));
+                }
+                subRet0 += sum0.reduceLanes(VectorOperators.ADD);
+                subRet1 += sum1.reduceLanes(VectorOperators.ADD);
+                subRet2 += sum2.reduceLanes(VectorOperators.ADD);
+                subRet3 += sum3.reduceLanes(VectorOperators.ADD);
+            }
+            // tail as bytes
+            in.seek(offset);
+            for (; i < length; i++) {
+                int dValue = in.readByte() & 0xFF;
+                subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF);
+                subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF);
+                subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF);
+                subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF);
+            }
+            scores[iter] = subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
+        }
+    }
+
+    @Override
+    public void scoreBulk(
+        byte[] q,
+        OptimizedScalarQuantizer.QuantizationResult queryCorrections,
+        VectorSimilarityFunction similarityFunction,
+        float centroidDp,
+        float[] scores
+    ) throws IOException {
+        assert q.length == length * 4;
+        // 128 / 8 == 16
+        if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) {
+            if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) {
+                score256Bulk(q, queryCorrections, similarityFunction, centroidDp, scores);
+                return;
+            } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) {
+                score128Bulk(q, queryCorrections, similarityFunction, centroidDp, scores);
+                return;
+            }
+        }
+        super.scoreBulk(q, queryCorrections, similarityFunction, centroidDp, scores);
+    }
+
+    private void score128Bulk(
+        byte[] q,
+        OptimizedScalarQuantizer.QuantizationResult queryCorrections,
+        VectorSimilarityFunction similarityFunction,
+        float centroidDp,
+        float[] scores
+    ) throws IOException {
+        quantizeScore128Bulk(q, BULK_SIZE, scores);
+        int limit = FLOAT_SPECIES_128.loopBound(BULK_SIZE);
+        int i = 0;
+        long offset = in.getFilePointer();
+        float ay = queryCorrections.lowerInterval();
+        float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
+        float y1 = queryCorrections.quantizedComponentSum();
+        for (; i < limit; i += FLOAT_SPECIES_128.length()) {
+            var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_128, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
+            var lx = FloatVector.fromMemorySegment(
+                FLOAT_SPECIES_128,
+                memorySegment,
+                offset + 4 * BULK_SIZE + i * Float.BYTES,
+                ByteOrder.LITTLE_ENDIAN
+            ).sub(ax);
+            var targetComponentSums = ShortVector.fromMemorySegment(
+                SHORT_SPECIES_128,
+                memorySegment,
+                offset + 8 * BULK_SIZE + i * Short.BYTES,
+                ByteOrder.LITTLE_ENDIAN
+            ).convert(VectorOperators.S2I, 0).reinterpretAsInts().and(0xffff).convert(VectorOperators.I2F, 0);
+            var additionalCorrections = FloatVector.fromMemorySegment(
+                FLOAT_SPECIES_128,
+                memorySegment,
+                offset + 10 * BULK_SIZE + i * Float.BYTES,
+                ByteOrder.LITTLE_ENDIAN
+            );
+            var qcDist = FloatVector.fromArray(FLOAT_SPECIES_128, scores, i);
+            // ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly *
+            // qcDist;
+            var res1 = ax.mul(ay).mul(dimensions);
+            var res2 = lx.mul(ay).mul(targetComponentSums);
+            var res3 = ax.mul(ly).mul(y1);
+            var res4 = lx.mul(ly).mul(qcDist);
+            var res = res1.add(res2).add(res3).add(res4);
+            // For euclidean, we need to invert the score and apply the additional correction, which is
+            // assumed to be the squared l2norm of the centroid centered vectors.
+            if (similarityFunction == EUCLIDEAN) {
+                res = res.mul(-2).add(additionalCorrections).add(queryCorrections.additionalCorrection()).add(1f);
+                res = FloatVector.broadcast(FLOAT_SPECIES_128, 1).div(res).max(0);
+                res.intoArray(scores, i);
+            } else {
+                // For cosine and max inner product, we need to apply the additional correction, which is
+                // assumed to be the non-centered dot-product between the vector and the centroid
+                res = res.add(queryCorrections.additionalCorrection()).add(additionalCorrections).sub(centroidDp);
+                if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
+                    res.intoArray(scores, i);
+                    // not sure how to do it better
+                    for (int j = 0; j < FLOAT_SPECIES_128.length(); j++) {
+                        scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]);
+                    }
+                } else {
+                    res = res.add(1f).mul(0.5f).max(0);
+                    res.intoArray(scores, i);
+                }
+            }
+        }
+        in.seek(offset + 14L * BULK_SIZE);
+    }
+
+    private void score256Bulk(
+        byte[] q,
+        OptimizedScalarQuantizer.QuantizationResult queryCorrections,
+        VectorSimilarityFunction similarityFunction,
+        float centroidDp,
+        float[] scores
+    ) throws IOException {
+        quantizeScore256Bulk(q, BULK_SIZE, scores);
+        int limit = FLOAT_SPECIES_256.loopBound(BULK_SIZE);
+        int i = 0;
+        long offset = in.getFilePointer();
+        float ay = queryCorrections.lowerInterval();
+        float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
+        float y1 = queryCorrections.quantizedComponentSum();
+        for (; i < limit; i += FLOAT_SPECIES_256.length()) {
+            var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_256, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
+            var lx = FloatVector.fromMemorySegment(
+                FLOAT_SPECIES_256,
+                memorySegment,
+                offset + 4 * BULK_SIZE + i * Float.BYTES,
+                ByteOrder.LITTLE_ENDIAN
+            ).sub(ax);
+            var targetComponentSums = ShortVector.fromMemorySegment(
+                SHORT_SPECIES_256,
+                memorySegment,
+                offset + 8 * BULK_SIZE + i * Short.BYTES,
+                ByteOrder.LITTLE_ENDIAN
+            ).convert(VectorOperators.S2I, 0).reinterpretAsInts().and(0xffff).convert(VectorOperators.I2F, 0);
+            var additionalCorrections = FloatVector.fromMemorySegment(
+                FLOAT_SPECIES_256,
+                memorySegment,
+                offset + 10 * BULK_SIZE + i * Float.BYTES,
+                ByteOrder.LITTLE_ENDIAN
+            );
+            var qcDist = FloatVector.fromArray(FLOAT_SPECIES_256, scores, i);
+            // ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly *
+            // qcDist;
+            var res1 = ax.mul(ay).mul(dimensions);
+            var res2 = lx.mul(ay).mul(targetComponentSums);
+            var res3 = ax.mul(ly).mul(y1);
+            var res4 = lx.mul(ly).mul(qcDist);
+            var res = res1.add(res2).add(res3).add(res4);
+            // For euclidean, we need to invert the score and apply the additional correction, which is
+            // assumed to be the squared l2norm of the centroid centered vectors.
+            if (similarityFunction == EUCLIDEAN) {
+                res = res.mul(-2).add(additionalCorrections).add(queryCorrections.additionalCorrection()).add(1f);
+                res = FloatVector.broadcast(FLOAT_SPECIES_256, 1).div(res).max(0);
+                res.intoArray(scores, i);
+            } else {
+                // For cosine and max inner product, we need to apply the additional correction, which is
+                // assumed to be the non-centered dot-product between the vector and the centroid
+                res = res.add(queryCorrections.additionalCorrection()).add(additionalCorrections).sub(centroidDp);
+                if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
+                    res.intoArray(scores, i);
+                    // not sure how to do it better
+                    for (int j = 0; j < FLOAT_SPECIES_256.length(); j++) {
+                        scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]);
+                    }
+                } else {
+                    res = res.add(1f).mul(0.5f).max(0);
+                    res.intoArray(scores, i);
+                }
+            }
+        }
+        in.seek(offset + 14L * BULK_SIZE);
+    }
+}

+ 17 - 0
libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java

@@ -9,6 +9,12 @@
 
 package org.elasticsearch.simdvec.internal.vectorization;
 
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.MemorySegmentAccessInput;
+
+import java.io.IOException;
+import java.lang.foreign.MemorySegment;
+
 final class PanamaESVectorizationProvider extends ESVectorizationProvider {
 
     private final ESVectorUtilSupport vectorUtilSupport;
@@ -21,4 +27,15 @@ final class PanamaESVectorizationProvider extends ESVectorizationProvider {
     public ESVectorUtilSupport getVectorUtilSupport() {
         return vectorUtilSupport;
     }
+
+    @Override
+    public ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException {
+        if (PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS && input instanceof MemorySegmentAccessInput msai) {
+            MemorySegment ms = msai.segmentSliceOrNull(0, input.length());
+            if (ms != null) {
+                return new MemorySegmentES91OSQVectorsScorer(input, dimension, ms);
+            }
+        }
+        return new ES91OSQVectorsScorer(input, dimension);
+    }
 }

+ 126 - 0
libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java

@@ -0,0 +1,126 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.simdvec.internal.vectorization;
+
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.store.IOContext;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.IndexOutput;
+import org.apache.lucene.store.MMapDirectory;
+import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
+
+import static org.hamcrest.Matchers.lessThan;
+
+public class ES91OSQVectorScorerTests extends BaseVectorizationTests {
+
+    public void testQuantizeScore() throws Exception {
+        final int dimensions = random().nextInt(1, 2000);
+        final int length = OptimizedScalarQuantizer.discretize(dimensions, 64) / 8;
+        final int numVectors = random().nextInt(1, 100);
+        final byte[] vector = new byte[length];
+        try (Directory dir = new MMapDirectory(createTempDir())) {
+            try (IndexOutput out = dir.createOutput("tests.bin", IOContext.DEFAULT)) {
+                for (int i = 0; i < numVectors; i++) {
+                    random().nextBytes(vector);
+                    out.writeBytes(vector, 0, length);
+                }
+            }
+            final byte[] query = new byte[4 * length];
+            random().nextBytes(query);
+            try (IndexInput in = dir.openInput("tests.bin", IOContext.DEFAULT)) {
+                // Work on a slice that has just the right number of bytes to make the test fail with an
+                // index-out-of-bounds in case the implementation reads more than the allowed number of
+                // padding bytes.
+                final IndexInput slice = in.slice("test", 0, (long) length * numVectors);
+                final ES91OSQVectorsScorer defaultScorer = defaultProvider().newES91OSQVectorsScorer(slice, dimensions);
+                final ES91OSQVectorsScorer panamaScorer = maybePanamaProvider().newES91OSQVectorsScorer(in, dimensions);
+                for (int i = 0; i < numVectors; i++) {
+                    assertEquals(defaultScorer.quantizeScore(query), panamaScorer.quantizeScore(query));
+                    assertEquals(in.getFilePointer(), slice.getFilePointer());
+                }
+                assertEquals((long) length * numVectors, slice.getFilePointer());
+            }
+        }
+    }
+
+    public void testScore() throws Exception {
+        final int maxDims = 512;
+        final int dimensions = random().nextInt(1, maxDims);
+        final int length = OptimizedScalarQuantizer.discretize(dimensions, 64) / 8;
+        final int numVectors = ES91OSQVectorsScorer.BULK_SIZE * random().nextInt(1, 10);
+        final byte[] vector = new byte[length];
+        int padding = random().nextInt(100);
+        byte[] paddingBytes = new byte[padding];
+        try (Directory dir = new MMapDirectory(createTempDir())) {
+            try (IndexOutput out = dir.createOutput("testScore.bin", IOContext.DEFAULT)) {
+                random().nextBytes(paddingBytes);
+                out.writeBytes(paddingBytes, 0, padding);
+                for (int i = 0; i < numVectors; i++) {
+                    random().nextBytes(vector);
+                    out.writeBytes(vector, 0, length);
+                    float lower = random().nextFloat();
+                    float upper = random().nextFloat() + lower / 2;
+                    float additionalCorrection = random().nextFloat();
+                    int targetComponentSum = randomIntBetween(0, dimensions / 2);
+                    out.writeInt(Float.floatToIntBits(lower));
+                    out.writeInt(Float.floatToIntBits(upper));
+                    out.writeShort((short) targetComponentSum);
+                    out.writeInt(Float.floatToIntBits(additionalCorrection));
+                }
+            }
+            final byte[] query = new byte[4 * length];
+            random().nextBytes(query);
+            float lower = random().nextFloat();
+            OptimizedScalarQuantizer.QuantizationResult result = new OptimizedScalarQuantizer.QuantizationResult(
+                lower,
+                random().nextFloat() + lower / 2,
+                random().nextFloat(),
+                randomIntBetween(0, dimensions * 2)
+            );
+            final float centroidDp = random().nextFloat();
+            final float[] scores1 = new float[ES91OSQVectorsScorer.BULK_SIZE];
+            final float[] scores2 = new float[ES91OSQVectorsScorer.BULK_SIZE];
+            for (VectorSimilarityFunction similarityFunction : VectorSimilarityFunction.values()) {
+                try (IndexInput in = dir.openInput("testScore.bin", IOContext.DEFAULT)) {
+                    in.seek(padding);
+                    assertEquals(in.length(), padding + (long) numVectors * (length + 14));
+                    // Work on a slice that has just the right number of bytes to make the test fail with an
+                    // index-out-of-bounds in case the implementation reads more than the allowed number of
+                    // padding bytes.
+                    for (int i = 0; i < numVectors; i += ES91OSQVectorsScorer.BULK_SIZE) {
+                        final IndexInput slice = in.slice(
+                            "test",
+                            in.getFilePointer(),
+                            (long) (length + 14) * ES91OSQVectorsScorer.BULK_SIZE
+                        );
+                        final ES91OSQVectorsScorer defaultScorer = defaultProvider().newES91OSQVectorsScorer(slice, dimensions);
+                        final ES91OSQVectorsScorer panamaScorer = maybePanamaProvider().newES91OSQVectorsScorer(in, dimensions);
+                        defaultScorer.scoreBulk(query, result, similarityFunction, centroidDp, scores1);
+                        panamaScorer.scoreBulk(query, result, similarityFunction, centroidDp, scores2);
+                        for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
+                            if (scores1[j] > (maxDims * Short.MAX_VALUE)) {
+                                int diff = (int) (scores1[j] - scores2[j]);
+                                assertThat("defaultScores: " + scores1[j] + " bulkScores: " + scores2[j], Math.abs(diff), lessThan(65));
+                            } else if (scores1[j] > (maxDims * Byte.MAX_VALUE)) {
+                                int diff = (int) (scores1[j] - scores2[j]);
+                                assertThat("defaultScores: " + scores1[j] + " bulkScores: " + scores2[j], Math.abs(diff), lessThan(9));
+                            } else {
+                                assertEquals(scores1[j], scores2[j], 1e-2f);
+                            }
+                        }
+                        assertEquals(((long) (ES91OSQVectorsScorer.BULK_SIZE) * (length + 14)), slice.getFilePointer());
+                        assertEquals(padding + ((long) (i + ES91OSQVectorsScorer.BULK_SIZE) * (length + 14)), in.getFilePointer());
+                    }
+                }
+            }
+        }
+    }
+}