
Introduce an int4 off-heap vector scorer (#129824)

* Introduce an int4 off-heap vector scorer

* iter

* Update server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java

Co-authored-by: Benjamin Trent <ben.w.trent@gmail.com>

---------

Co-authored-by: Benjamin Trent <ben.w.trent@gmail.com>
Ignacio Vera, 3 months ago
parent commit ffea6ca2bf

+ 123 - 0
benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Int4ScorerBenchmark.java

@@ -0,0 +1,123 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+package org.elasticsearch.benchmark.vector;
+
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.store.IOContext;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.IndexOutput;
+import org.apache.lucene.store.MMapDirectory;
+import org.apache.lucene.util.VectorUtil;
+import org.elasticsearch.common.logging.LogConfigurator;
+import org.elasticsearch.core.IOUtils;
+import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
+import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Param;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.TearDown;
+import org.openjdk.jmh.annotations.Warmup;
+import org.openjdk.jmh.infra.Blackhole;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.concurrent.TimeUnit;
+
+@BenchmarkMode(Mode.Throughput)
+@OutputTimeUnit(TimeUnit.MILLISECONDS)
+@State(Scope.Benchmark)
+// first iteration is complete garbage, so make sure we really warm up
+@Warmup(iterations = 4, time = 1)
+// real iterations. not useful to spend tons of time here, better to fork more
+@Measurement(iterations = 5, time = 1)
+// engage some noise reduction
+@Fork(value = 1)
+public class Int4ScorerBenchmark {
+
+    static {
+        LogConfigurator.configureESLogging(); // native access requires logging to be initialized
+    }
+
+    @Param({ "384", "702", "1024" })
+    int dims;
+
+    int numVectors = 200;
+    int numQueries = 10;
+
+    byte[] scratch;
+    byte[][] binaryVectors;
+    byte[][] binaryQueries;
+
+    ES91Int4VectorsScorer scorer;
+    Directory dir;
+    IndexInput in;
+
+    @Setup
+    public void setup() throws IOException {
+        binaryVectors = new byte[numVectors][dims];
+        dir = new MMapDirectory(Files.createTempDirectory("vectorData"));
+        try (IndexOutput out = dir.createOutput("vectors", IOContext.DEFAULT)) {
+            for (byte[] binaryVector : binaryVectors) {
+                for (int i = 0; i < dims; i++) {
+                    // 4-bit quantization
+                    binaryVector[i] = (byte) ThreadLocalRandom.current().nextInt(16);
+                }
+                out.writeBytes(binaryVector, 0, binaryVector.length);
+            }
+        }
+
+        in = dir.openInput("vectors", IOContext.DEFAULT);
+        binaryQueries = new byte[numQueries][dims];
+        for (byte[] binaryQuery : binaryQueries) {
+            for (int i = 0; i < dims; i++) {
+                // 4-bit quantization
+                binaryQuery[i] = (byte) ThreadLocalRandom.current().nextInt(16);
+            }
+        }
+
+        scratch = new byte[dims];
+        scorer = ESVectorizationProvider.getInstance().newES91Int4VectorsScorer(in, dims);
+    }
+
+    @TearDown
+    public void teardown() throws IOException {
+        IOUtils.close(dir, in);
+    }
+
+    @Benchmark
+    @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
+    public void scoreFromArray(Blackhole bh) throws IOException {
+        for (int j = 0; j < numQueries; j++) {
+            in.seek(0);
+            for (int i = 0; i < numVectors; i++) {
+                in.readBytes(scratch, 0, dims);
+                bh.consume(VectorUtil.int4DotProduct(binaryQueries[j], scratch));
+            }
+        }
+    }
+
+    @Benchmark
+    @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
+    public void scoreFromMemorySegmentOnlyVector(Blackhole bh) throws IOException {
+        for (int j = 0; j < numQueries; j++) {
+            in.seek(0);
+            for (int i = 0; i < numVectors; i++) {
+                bh.consume(scorer.int4DotProduct(binaryQueries[j]));
+            }
+        }
+    }
+}

+ 43 - 0
libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91Int4VectorsScorer.java

@@ -0,0 +1,43 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+package org.elasticsearch.simdvec;
+
+import org.apache.lucene.store.IndexInput;
+
+import java.io.IOException;
+
+/** Scorer for quantized vectors stored as an {@link IndexInput}.
+ * <p>
+ * Similar to {@link org.apache.lucene.util.VectorUtil#int4DotProduct(byte[], byte[])} but
+ * one value is read directly from an {@link IndexInput}.
+ */
+public class ES91Int4VectorsScorer {
+
+    /** The wrapped {@link IndexInput}. */
+    protected final IndexInput in;
+    protected final int dimensions;
+    protected byte[] scratch;
+
+    /** Sole constructor, called by sub-classes. */
+    public ES91Int4VectorsScorer(IndexInput in, int dimensions) {
+        this.in = in;
+        this.dimensions = dimensions;
+        scratch = new byte[dimensions];
+    }
+
+    public long int4DotProduct(byte[] b) throws IOException {
+        in.readBytes(scratch, 0, dimensions);
+        int total = 0;
+        for (int i = 0; i < dimensions; i++) {
+            total += scratch[i] * b[i];
+        }
+        return total;
+    }
+}

+ 4 - 0
libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java

@@ -47,6 +47,10 @@ public class ESVectorUtil {
         return ESVectorizationProvider.getInstance().newES91OSQVectorsScorer(input, dimension);
     }
 
+    public static ES91Int4VectorsScorer getES91Int4VectorsScorer(IndexInput input, int dimension) throws IOException {
+        return ESVectorizationProvider.getInstance().newES91Int4VectorsScorer(input, dimension);
+    }
+
     public static long ipByteBinByte(byte[] q, byte[] d) {
         if (q.length != d.length * B_QUERY) {
             throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + B_QUERY + " x " + d.length);

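A minimal usage sketch (not part of this change) for the new entry point, assuming an IndexInput positioned at the start of a block of int4-quantized vectors; `input`, `dims`, `numVectors`, and `query` are illustrative names:

    // Score numVectors int4 vectors stored back to back in `input` against `query`.
    // The vectorization provider transparently hands back the MemorySegment-backed scorer when it can.
    ES91Int4VectorsScorer scorer = ESVectorUtil.getES91Int4VectorsScorer(input, dims);
    for (int i = 0; i < numVectors; i++) {
        long dotProduct = scorer.int4DotProduct(query); // advances `input` by dims bytes
        // ... combine with corrective terms to produce the final score
    }
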
+ 6 - 0
libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java

@@ -10,6 +10,7 @@
 package org.elasticsearch.simdvec.internal.vectorization;
 
 import org.apache.lucene.store.IndexInput;
+import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
 import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
 
 import java.io.IOException;
@@ -30,4 +31,9 @@ final class DefaultESVectorizationProvider extends ESVectorizationProvider {
     public ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException {
         return new ES91OSQVectorsScorer(input, dimension);
     }
+
+    @Override
+    public ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dimension) throws IOException {
+        return new ES91Int4VectorsScorer(input, dimension);
+    }
 }

+ 4 - 0
libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java

@@ -10,6 +10,7 @@
 package org.elasticsearch.simdvec.internal.vectorization;
 
 import org.apache.lucene.store.IndexInput;
+import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
 import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
 
 import java.io.IOException;
@@ -31,6 +32,9 @@ public abstract class ESVectorizationProvider {
     /** Create a new {@link ES91OSQVectorsScorer} for the given {@link IndexInput}. */
     public abstract ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException;
 
+    /** Create a new {@link ES91Int4VectorsScorer} for the given {@link IndexInput}. */
+    public abstract ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dimension) throws IOException;
+
     // visible for tests
     static ESVectorizationProvider lookup(boolean testMode) {
         return new DefaultESVectorizationProvider();

+ 4 - 0
libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java

@@ -13,6 +13,7 @@ import org.apache.lucene.store.IndexInput;
 import org.apache.lucene.util.Constants;
 import org.elasticsearch.logging.LogManager;
 import org.elasticsearch.logging.Logger;
+import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
 import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
 
 import java.io.IOException;
@@ -38,6 +39,9 @@ public abstract class ESVectorizationProvider {
     /** Create a new {@link ES91OSQVectorsScorer} for the given {@link IndexInput}. */
     public abstract ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException;
 
+    /** Create a new {@link ES91Int4VectorsScorer} for the given {@link IndexInput}. */
+    public abstract ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dimension) throws IOException;
+
     // visible for tests
     static ESVectorizationProvider lookup(boolean testMode) {
         final int runtimeVersion = Runtime.version().feature();

+ 191 - 0
libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91Int4VectorsScorer.java

@@ -0,0 +1,191 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+package org.elasticsearch.simdvec.internal.vectorization;
+
+import jdk.incubator.vector.ByteVector;
+import jdk.incubator.vector.IntVector;
+import jdk.incubator.vector.ShortVector;
+import jdk.incubator.vector.Vector;
+import jdk.incubator.vector.VectorSpecies;
+
+import org.apache.lucene.store.IndexInput;
+import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
+
+import java.io.IOException;
+import java.lang.foreign.MemorySegment;
+
+import static java.nio.ByteOrder.LITTLE_ENDIAN;
+import static jdk.incubator.vector.VectorOperators.ADD;
+import static jdk.incubator.vector.VectorOperators.B2I;
+import static jdk.incubator.vector.VectorOperators.B2S;
+import static jdk.incubator.vector.VectorOperators.S2I;
+
+/** Panamized scorer for quantized vectors stored as an {@link IndexInput}.
+ * <p>
+ * Similar to {@link org.apache.lucene.util.VectorUtil#int4DotProduct(byte[], byte[])} but
+ * one value is read directly from a {@link MemorySegment}.
+ */
+public final class MemorySegmentES91Int4VectorsScorer extends ES91Int4VectorsScorer {
+
+    private static final VectorSpecies<Byte> BYTE_SPECIES_64 = ByteVector.SPECIES_64;
+    private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
+
+    private static final VectorSpecies<Short> SHORT_SPECIES_128 = ShortVector.SPECIES_128;
+    private static final VectorSpecies<Short> SHORT_SPECIES_256 = ShortVector.SPECIES_256;
+
+    private static final VectorSpecies<Integer> INT_SPECIES_128 = IntVector.SPECIES_128;
+    private static final VectorSpecies<Integer> INT_SPECIES_256 = IntVector.SPECIES_256;
+    private static final VectorSpecies<Integer> INT_SPECIES_512 = IntVector.SPECIES_512;
+
+    private final MemorySegment memorySegment;
+
+    public MemorySegmentES91Int4VectorsScorer(IndexInput in, int dimensions, MemorySegment memorySegment) {
+        super(in, dimensions);
+        this.memorySegment = memorySegment;
+    }
+
+    @Override
+    public long int4DotProduct(byte[] q) throws IOException {
+        if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 512 || PanamaESVectorUtilSupport.VECTOR_BITSIZE == 256) {
+            return dotProduct(q);
+        }
+        int i = 0;
+        int res = 0;
+        if (dimensions >= 32 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) {
+            i += BYTE_SPECIES_128.loopBound(dimensions);
+            res += int4DotProductBody128(q, i);
+        }
+        in.readBytes(scratch, i, dimensions - i);
+        while (i < dimensions) {
+            res += scratch[i] * q[i++];
+        }
+        return res;
+    }
+
+    private int int4DotProductBody128(byte[] q, int limit) throws IOException {
+        int sum = 0;
+        long offset = in.getFilePointer();
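+        // work in 1024-byte chunks so the 16-bit accumulators cannot overflow: each lane accumulates at most
+        // 1024 / 16 = 64 products, and each int4 product is at most 15 * 15 = 225, well under Short.MAX_VALUE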
+        for (int i = 0; i < limit; i += 1024) {
+            ShortVector acc0 = ShortVector.zero(SHORT_SPECIES_128);
+            ShortVector acc1 = ShortVector.zero(SHORT_SPECIES_128);
+            int innerLimit = Math.min(limit - i, 1024);
+            for (int j = 0; j < innerLimit; j += BYTE_SPECIES_128.length()) {
+                ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i + j);
+                ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i + j, LITTLE_ENDIAN);
+                ByteVector prod8 = va8.mul(vb8);
+                ShortVector prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
+                acc0 = acc0.add(prod16.and((short) 255));
+                va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i + j + 8);
+                vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i + j + 8, LITTLE_ENDIAN);
+                prod8 = va8.mul(vb8);
+                prod16 = prod8.convertShape(B2S, SHORT_SPECIES_128, 0).reinterpretAsShorts();
+                acc1 = acc1.add(prod16.and((short) 255));
+            }
+
+            IntVector intAcc0 = acc0.convertShape(S2I, INT_SPECIES_128, 0).reinterpretAsInts();
+            IntVector intAcc1 = acc0.convertShape(S2I, INT_SPECIES_128, 1).reinterpretAsInts();
+            IntVector intAcc2 = acc1.convertShape(S2I, INT_SPECIES_128, 0).reinterpretAsInts();
+            IntVector intAcc3 = acc1.convertShape(S2I, INT_SPECIES_128, 1).reinterpretAsInts();
+            sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(ADD);
+        }
+        in.seek(offset + limit);
+        return sum;
+    }
+
+    private long dotProduct(byte[] q) throws IOException {
+        int i = 0;
+        int res = 0;
+
+        // only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit
+        // vectors (256-bit on intel to dodge performance landmines)
+        if (dimensions >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) {
+            // compute vectorized dot product consistent with VPDPBUSD instruction
+            if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 512) {
+                i += BYTE_SPECIES_128.loopBound(dimensions);
+                res += dotProductBody512(q, i);
+            } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 256) {
+                i += BYTE_SPECIES_64.loopBound(dimensions);
+                res += dotProductBody256(q, i);
+            } else {
+                // tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
+                i += BYTE_SPECIES_64.loopBound(dimensions - BYTE_SPECIES_64.length());
+                res += dotProductBody128(q, i);
+            }
+        }
+        // scalar tail
+        for (; i < q.length; i++) {
+            res += in.readByte() * q[i];
+        }
+        return res;
+    }
+
+    /** vectorized dot product body (512 bit vectors) */
+    private int dotProductBody512(byte[] q, int limit) throws IOException {
+        IntVector acc = IntVector.zero(INT_SPECIES_512);
+        long offset = in.getFilePointer();
+        for (int i = 0; i < limit; i += BYTE_SPECIES_128.length()) {
+            ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_128, q, i);
+            ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_128, memorySegment, offset + i, LITTLE_ENDIAN);
+
+            // 16-bit multiply: avoid AVX-512 heavy multiply on zmm
+            Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES_256, 0);
+            Vector<Short> vb16 = vb8.convertShape(B2S, SHORT_SPECIES_256, 0);
+            Vector<Short> prod16 = va16.mul(vb16);
+
+            // 32-bit add
+            Vector<Integer> prod32 = prod16.convertShape(S2I, INT_SPECIES_512, 0);
+            acc = acc.add(prod32);
+        }
+
+        in.seek(offset + limit); // advance the input stream
+        // reduce
+        return acc.reduceLanes(ADD);
+    }
+
+    /** vectorized dot product body (256 bit vectors) */
+    private int dotProductBody256(byte[] q, int limit) throws IOException {
+        IntVector acc = IntVector.zero(INT_SPECIES_256);
+        long offset = in.getFilePointer();
+        for (int i = 0; i < limit; i += BYTE_SPECIES_64.length()) {
+            ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i);
+            ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN);
+
+            // 32-bit multiply and add into accumulator
+            Vector<Integer> va32 = va8.convertShape(B2I, INT_SPECIES_256, 0);
+            Vector<Integer> vb32 = vb8.convertShape(B2I, INT_SPECIES_256, 0);
+            acc = acc.add(va32.mul(vb32));
+        }
+        in.seek(offset + limit);
+        // reduce
+        return acc.reduceLanes(ADD);
+    }
+
+    /** vectorized dot product body (128 bit vectors) */
+    private int dotProductBody128(byte[] q, int limit) throws IOException {
+        IntVector acc = IntVector.zero(INT_SPECIES_128);
+        long offset = in.getFilePointer();
+        // 4 bytes at a time (re-loading half the vector each time!)
+        for (int i = 0; i < limit; i += BYTE_SPECIES_64.length() >> 1) {
+            // load 8 bytes
+            ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i);
+            ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN);
+
+            // process first "half" only: 16-bit multiply
+            Vector<Short> va16 = va8.convert(B2S, 0);
+            Vector<Short> vb16 = vb8.convert(B2S, 0);
+            Vector<Short> prod16 = va16.mul(vb16);
+
+            // 32-bit add
+            acc = acc.add(prod16.convertShape(S2I, INT_SPECIES_128, 0));
+        }
+        in.seek(offset + limit);
+        // reduce
+        return acc.reduceLanes(ADD);
+    }
+}

+ 12 - 0
libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java

@@ -11,6 +11,7 @@ package org.elasticsearch.simdvec.internal.vectorization;
 
 import org.apache.lucene.store.IndexInput;
 import org.apache.lucene.store.MemorySegmentAccessInput;
+import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
 import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
 
 import java.io.IOException;
@@ -39,4 +40,15 @@ final class PanamaESVectorizationProvider extends ESVectorizationProvider {
         }
         return new ES91OSQVectorsScorer(input, dimension);
     }
+
+    @Override
+    public ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dimension) throws IOException {
+        if (PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS && input instanceof MemorySegmentAccessInput msai) {
+            MemorySegment ms = msai.segmentSliceOrNull(0, input.length());
+            if (ms != null) {
+                return new MemorySegmentES91Int4VectorsScorer(input, dimension, ms);
+            }
+        }
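+        // no single MemorySegment available (or no fast integer vectors): fall back to the scalar scorer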
+        return new ES91Int4VectorsScorer(input, dimension);
+    }
 }

+ 60 - 0
libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91Int4VectorScorerTests.java

@@ -0,0 +1,60 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.simdvec.internal.vectorization;
+
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.store.IOContext;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.IndexOutput;
+import org.apache.lucene.store.MMapDirectory;
+import org.apache.lucene.util.VectorUtil;
+import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
+
+public class ES91Int4VectorScorerTests extends BaseVectorizationTests {
+
+    public void testInt4DotProduct() throws Exception {
+        // only even dimensions are supported
+        final int dimensions = random().nextInt(1, 1000) * 2;
+        final int numVectors = random().nextInt(1, 100);
+        final byte[] vector = new byte[dimensions];
+        try (Directory dir = new MMapDirectory(createTempDir())) {
+            try (IndexOutput out = dir.createOutput("tests.bin", IOContext.DEFAULT)) {
+                for (int i = 0; i < numVectors; i++) {
+                    for (int j = 0; j < dimensions; j++) {
+                        vector[j] = (byte) random().nextInt(16); // 4-bit quantization
+                    }
+                    out.writeBytes(vector, 0, dimensions);
+                }
+            }
+            final byte[] query = new byte[dimensions];
+            for (int j = 0; j < dimensions; j++) {
+                query[j] = (byte) random().nextInt(16); // 4-bit quantization
+            }
+            try (IndexInput in = dir.openInput("tests.bin", IOContext.DEFAULT)) {
+                // Work on a slice with exactly the right number of bytes so the test fails with an
+                // index-out-of-bounds exception if the implementation reads past the end of the vector data.
+                final IndexInput slice = in.slice("test", 0, (long) dimensions * numVectors);
+                final IndexInput slice2 = in.slice("test2", 0, (long) dimensions * numVectors);
+                final ES91Int4VectorsScorer defaultScorer = defaultProvider().newES91Int4VectorsScorer(slice, dimensions);
+                final ES91Int4VectorsScorer panamaScorer = maybePanamaProvider().newES91Int4VectorsScorer(slice2, dimensions);
+                for (int i = 0; i < numVectors; i++) {
+                    in.readBytes(vector, 0, dimensions);
+                    long val = VectorUtil.int4DotProduct(vector, query);
+                    assertEquals(val, defaultScorer.int4DotProduct(query));
+                    assertEquals(val, panamaScorer.int4DotProduct(query));
+                    assertEquals(in.getFilePointer(), slice.getFilePointer());
+                    assertEquals(in.getFilePointer(), slice2.getFilePointer());
+                }
+                assertEquals((long) dimensions * numVectors, in.getFilePointer());
+            }
+        }
+    }
+}

+ 58 - 71
server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java

@@ -19,6 +19,7 @@ import org.apache.lucene.util.ArrayUtil;
 import org.apache.lucene.util.VectorUtil;
 import org.apache.lucene.util.hnsw.NeighborQueue;
 import org.elasticsearch.index.codec.vectors.reflect.OffHeapStats;
+import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
 import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
 import org.elasticsearch.simdvec.ESVectorUtil;
 
@@ -48,25 +49,23 @@ public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeap
     @Override
     CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery)
         throws IOException {
-        FieldEntry fieldEntry = fields.get(fieldInfo.number);
-        float[] globalCentroid = fieldEntry.globalCentroid();
-        float globalCentroidDp = fieldEntry.globalCentroidDp();
-        OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
-        byte[] quantized = new byte[targetQuery.length];
-        float[] targetScratch = ArrayUtil.copyArray(targetQuery);
-        OptimizedScalarQuantizer.QuantizationResult queryParams = scalarQuantizer.scalarQuantize(
-            targetScratch,
+        final FieldEntry fieldEntry = fields.get(fieldInfo.number);
+        final float globalCentroidDp = fieldEntry.globalCentroidDp();
+        final OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
+        final byte[] quantized = new byte[targetQuery.length];
+        final OptimizedScalarQuantizer.QuantizationResult queryParams = scalarQuantizer.scalarQuantize(
+            ArrayUtil.copyArray(targetQuery),
             quantized,
             (byte) 4,
-            globalCentroid
+            fieldEntry.globalCentroid()
         );
+        final ES91Int4VectorsScorer scorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, fieldInfo.getVectorDimension());
         return new CentroidQueryScorer() {
             int currentCentroid = -1;
-            private final byte[] quantizedCentroid = new byte[fieldInfo.getVectorDimension()];
             private final float[] centroid = new float[fieldInfo.getVectorDimension()];
             private final float[] centroidCorrectiveValues = new float[3];
-            private int quantizedCentroidComponentSum;
-            private final long centroidByteSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES;
+            private final long rawCentroidsOffset = (long) numCentroids * (fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES);
+            private final long rawCentroidsByteSize = (long) Float.BYTES * fieldInfo.getVectorDimension();
 
             @Override
             public int size() {
@@ -75,35 +74,67 @@ public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeap
 
             @Override
             public float[] centroid(int centroidOrdinal) throws IOException {
-                readQuantizedAndRawCentroid(centroidOrdinal);
+                if (centroidOrdinal != currentCentroid) {
+                    centroids.seek(rawCentroidsOffset + rawCentroidsByteSize * centroidOrdinal);
+                    centroids.readFloats(centroid, 0, centroid.length);
+                    currentCentroid = centroidOrdinal;
+                }
                 return centroid;
             }
 
-            private void readQuantizedAndRawCentroid(int centroidOrdinal) throws IOException {
-                if (centroidOrdinal == currentCentroid) {
-                    return;
+            public void bulkScore(NeighborQueue queue) throws IOException {
+                // TODO: bulk score centroids like we do with posting lists
+                centroids.seek(0L);
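+                // quantized centroid records are laid out contiguously from the start of the input, so scan them in order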
+                for (int i = 0; i < numCentroids; i++) {
+                    queue.add(i, score());
                 }
-                centroids.seek(centroidOrdinal * centroidByteSize);
-                quantizedCentroidComponentSum = readQuantizedValue(centroids, quantizedCentroid, centroidCorrectiveValues);
-                centroids.seek(numCentroids * centroidByteSize + (long) Float.BYTES * quantizedCentroid.length * centroidOrdinal);
-                centroids.readFloats(centroid, 0, centroid.length);
-                currentCentroid = centroidOrdinal;
             }
 
-            @Override
-            public float score(int centroidOrdinal) throws IOException {
-                readQuantizedAndRawCentroid(centroidOrdinal);
+            private float score() throws IOException {
+                final float qcDist = scorer.int4DotProduct(quantized);
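+                // int4DotProduct consumed the quantized codes; the corrective floats and component sum follow in the record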
+                centroids.readFloats(centroidCorrectiveValues, 0, 3);
+                final int quantizedCentroidComponentSum = Short.toUnsignedInt(centroids.readShort());
                 return int4QuantizedScore(
-                    quantized,
+                    qcDist,
                     queryParams,
                     fieldInfo.getVectorDimension(),
-                    quantizedCentroid,
                     centroidCorrectiveValues,
                     quantizedCentroidComponentSum,
                     globalCentroidDp,
                     fieldInfo.getVectorSimilarityFunction()
                 );
             }
+
+            // TODO can we do this in off-heap blocks?
+            private float int4QuantizedScore(
+                float qcDist,
+                OptimizedScalarQuantizer.QuantizationResult queryCorrections,
+                int dims,
+                float[] targetCorrections,
+                int targetComponentSum,
+                float centroidDp,
+                VectorSimilarityFunction similarityFunction
+            ) {
+                float ax = targetCorrections[0];
+                // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
+                float lx = (targetCorrections[1] - ax) * FOUR_BIT_SCALE;
+                float ay = queryCorrections.lowerInterval();
+                float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
+                float y1 = queryCorrections.quantizedComponentSum();
+                float score = ax * ay * dims + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
+                if (similarityFunction == EUCLIDEAN) {
+                    score = queryCorrections.additionalCorrection() + targetCorrections[2] - 2 * score;
+                    return Math.max(1 / (1f + score), 0);
+                } else {
+                    // For cosine and max inner product, we need to apply the additional correction, which is
+                    // assumed to be the non-centered dot-product between the vector and the centroid
+                    score += queryCorrections.additionalCorrection() + targetCorrections[2] - centroidDp;
+                    if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
+                        return VectorUtil.scaleMaxInnerProductScore(score);
+                    }
+                    return Math.max((1f + score) / 2f, 0);
+                }
+            }
         };
     }
 
@@ -111,10 +142,7 @@ public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeap
     NeighborQueue scorePostingLists(FieldInfo fieldInfo, KnnCollector knnCollector, CentroidQueryScorer centroidQueryScorer, int nProbe)
         throws IOException {
         NeighborQueue neighborQueue = new NeighborQueue(centroidQueryScorer.size(), true);
-        // TODO Off heap scoring for quantized centroids?
-        for (int centroid = 0; centroid < centroidQueryScorer.size(); centroid++) {
-            neighborQueue.add(centroid, centroidQueryScorer.score(centroid));
-        }
+        centroidQueryScorer.bulkScore(neighborQueue);
         return neighborQueue;
     }
 
@@ -125,39 +153,6 @@ public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeap
         return new MemorySegmentPostingsVisitor(target, indexInput.clone(), entry, fieldInfo, needsScoring);
     }
 
-    // TODO can we do this in off-heap blocks?
-    static float int4QuantizedScore(
-        byte[] quantizedQuery,
-        OptimizedScalarQuantizer.QuantizationResult queryCorrections,
-        int dims,
-        byte[] binaryCode,
-        float[] targetCorrections,
-        int targetComponentSum,
-        float centroidDp,
-        VectorSimilarityFunction similarityFunction
-    ) {
-        float qcDist = VectorUtil.int4DotProduct(quantizedQuery, binaryCode);
-        float ax = targetCorrections[0];
-        // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
-        float lx = (targetCorrections[1] - ax) * FOUR_BIT_SCALE;
-        float ay = queryCorrections.lowerInterval();
-        float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
-        float y1 = queryCorrections.quantizedComponentSum();
-        float score = ax * ay * dims + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
-        if (similarityFunction == EUCLIDEAN) {
-            score = queryCorrections.additionalCorrection() + targetCorrections[2] - 2 * score;
-            return Math.max(1 / (1f + score), 0);
-        } else {
-            // For cosine and max inner product, we need to apply the additional correction, which is
-            // assumed to be the non-centered dot-product between the vector and the centroid
-            score += queryCorrections.additionalCorrection() + targetCorrections[2] - centroidDp;
-            if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
-                return VectorUtil.scaleMaxInnerProductScore(score);
-            }
-            return Math.max((1f + score) / 2f, 0);
-        }
-    }
-
     @Override
     public Map<String, Long> getOffHeapByteSize(FieldInfo fieldInfo) {
         return Map.of();
@@ -356,12 +351,4 @@ public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeap
         }
     }
 
-    static int readQuantizedValue(IndexInput indexInput, byte[] binaryValue, float[] corrections) throws IOException {
-        assert corrections.length == 3;
-        indexInput.readBytes(binaryValue, 0, binaryValue.length);
-        corrections[0] = Float.intBitsToFloat(indexInput.readInt());
-        corrections[1] = Float.intBitsToFloat(indexInput.readInt());
-        corrections[2] = Float.intBitsToFloat(indexInput.readInt());
-        return Short.toUnsignedInt(indexInput.readShort());
-    }
 }

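For readers following the new offset arithmetic in getCentroidScorer above: the reader assumes the per-field centroid input stores all quantized centroid records first, followed by all raw float centroids. A sketch of that assumed layout, inferred from the offsets computed in the reader (variable names are illustrative):

    // Quantized block: numCentroids records of
    //   dims bytes of int4 codes + 3 * Float.BYTES corrective values + Short.BYTES component sum
    // Raw block:       numCentroids records of dims * Float.BYTES raw centroid floats
    long quantizedRecordBytes = dims + 3 * Float.BYTES + Short.BYTES;
    long rawCentroidsOffset = (long) numCentroids * quantizedRecordBytes; // where the raw block starts
    long rawCentroidBytes = (long) Float.BYTES * dims;                    // one raw centroid
    // centroid(i) seeks to rawCentroidsOffset + i * rawCentroidBytes and reads dims floats
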
+ 1 - 1
server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java

@@ -332,7 +332,7 @@ public abstract class IVFVectorsReader extends KnnVectorsReader {
 
         float[] centroid(int centroidOrdinal) throws IOException;
 
-        float score(int centroidOrdinal) throws IOException;
+        void bulkScore(NeighborQueue queue) throws IOException;
     }
 
     interface PostingVisitor {