1
0
Эх сурвалжийг харах

Add panama implementations of byte-bit and float-bit script operations (#124722)

Simon Cooper 6 сар өмнө
parent
commit
7f1203e472

+ 31 - 16
benchmarks/src/main/java/org/elasticsearch/benchmark/vector/DistanceFunctionBenchmark.java

@@ -13,6 +13,8 @@ import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.common.logging.LogConfigurator;
 import org.elasticsearch.index.IndexVersion;
 import org.elasticsearch.script.field.vectors.BinaryDenseVector;
+import org.elasticsearch.script.field.vectors.BitBinaryDenseVector;
+import org.elasticsearch.script.field.vectors.BitKnnDenseVector;
 import org.elasticsearch.script.field.vectors.ByteBinaryDenseVector;
 import org.elasticsearch.script.field.vectors.ByteKnnDenseVector;
 import org.elasticsearch.script.field.vectors.DenseVector;
@@ -37,30 +39,30 @@ import java.util.concurrent.TimeUnit;
 import java.util.function.DoubleSupplier;
 
 /**
- * Various benchmarks for the distance functions
- * used by indexed and non-indexed vectors.
- * Parameters include element, dims, function, and type.
+ * Various benchmarks for the distance functions used by indexed and non-indexed vectors.
+ * Parameters include doc and query type, dims, function, and implementation.
  * For individual local tests it may be useful to increase
- * fork, measurement, and operations per invocation. (Note
- * to also update the benchmark loop if operations per invocation
- * is increased.)
+ * fork, measurement, and operations per invocation.
  */
 @Fork(1)
 @Warmup(iterations = 1)
 @Measurement(iterations = 2)
 @BenchmarkMode(Mode.AverageTime)
 @OutputTimeUnit(TimeUnit.NANOSECONDS)
-@OperationsPerInvocation(25000)
+@OperationsPerInvocation(DistanceFunctionBenchmark.OPERATIONS)
 @State(Scope.Benchmark)
 public class DistanceFunctionBenchmark {
 
+    public static final int OPERATIONS = 25000;
+
     static {
         LogConfigurator.configureESLogging();
     }
 
     public enum VectorType {
         FLOAT,
-        BYTE
+        BYTE,
+        BIT
     }
 
     public enum Function {
@@ -122,7 +124,7 @@ public class DistanceFunctionBenchmark {
     }
 
     private static BytesRef generateVectorData(float[] vector, float mag) {
-        ByteBuffer buffer = ByteBuffer.allocate(vector.length * 4 + 4);
+        ByteBuffer buffer = ByteBuffer.allocate(vector.length * Float.BYTES + Float.BYTES);
         for (float f : vector) {
             buffer.putFloat(f);
         }
@@ -133,7 +135,7 @@ public class DistanceFunctionBenchmark {
     private static BytesRef generateVectorData(byte[] vector) {
         float mag = calculateMag(vector);
 
-        ByteBuffer buffer = ByteBuffer.allocate(vector.length + 4);
+        ByteBuffer buffer = ByteBuffer.allocate(vector.length + Float.BYTES);
         buffer.put(vector);
         buffer.putFloat(mag);
         return new BytesRef(buffer.array());
@@ -141,16 +143,21 @@ public class DistanceFunctionBenchmark {
 
     @Setup
     public void findBenchmarkImpl() {
+        if (dims % 8 != 0) throw new IllegalArgumentException("Dims must be a multiple of 8");
         Random r = new Random();
 
         float[] floatDocVector = new float[dims];
         byte[] byteDocVector = new byte[dims];
+        byte[] bitDocVector = new byte[dims / 8];
 
         float[] floatQueryVector = new float[dims];
         byte[] byteQueryVector = new byte[dims];
+        byte[] bitQueryVector = new byte[dims / 8];
 
         r.nextBytes(byteDocVector);
+        r.nextBytes(bitDocVector);
         r.nextBytes(byteQueryVector);
+        r.nextBytes(bitQueryVector);
         for (int i = 0; i < dims; i++) {
             floatDocVector[i] = r.nextFloat();
             floatQueryVector[i] = r.nextFloat();
@@ -179,10 +186,11 @@ public class DistanceFunctionBenchmark {
             };
             case BYTE -> switch (type) {
                 case KNN -> new ByteKnnDenseVector(byteDocVector);
-                case BINARY -> {
-                    BytesRef vectorData = generateVectorData(byteDocVector);
-                    yield new ByteBinaryDenseVector(byteDocVector, vectorData, dims);
-                }
+                case BINARY -> new ByteBinaryDenseVector(byteDocVector, generateVectorData(byteDocVector), dims);
+            };
+            case BIT -> switch (type) {
+                case KNN -> new BitKnnDenseVector(bitDocVector);
+                case BINARY -> new BitBinaryDenseVector(bitDocVector, new BytesRef(bitDocVector), bitDocVector.length);
             };
         };
 
@@ -204,13 +212,20 @@ public class DistanceFunctionBenchmark {
                 case L2 -> () -> vectorImpl.l2Norm(byteQueryVector);
                 case HAMMING -> () -> vectorImpl.hamming(byteQueryVector);
             };
+            case BIT -> switch (function) {
+                case DOT -> () -> vectorImpl.dotProduct(bitQueryVector);
+                case COSINE -> throw new UnsupportedOperationException("Unsupported function " + function);
+                case L1 -> () -> vectorImpl.l1Norm(bitQueryVector);
+                case L2 -> () -> vectorImpl.l2Norm(bitQueryVector);
+                case HAMMING -> () -> vectorImpl.hamming(bitQueryVector);
+            };
         };
     }
 
     @Fork(1)
     @Benchmark
     public void benchmark(Blackhole blackhole) {
-        for (int i = 0; i < 25000; ++i) {
+        for (int i = 0; i < OPERATIONS; ++i) {
             blackhole.consume(benchmarkImpl.getAsDouble());
         }
     }
@@ -218,7 +233,7 @@ public class DistanceFunctionBenchmark {
     @Fork(value = 1, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
     @Benchmark
     public void vectorBenchmark(Blackhole blackhole) {
-        for (int i = 0; i < 25000; ++i) {
+        for (int i = 0; i < OPERATIONS; ++i) {
             blackhole.consume(benchmarkImpl.getAsDouble());
         }
     }

+ 6 - 0
docs/changelog/124722.yaml

@@ -0,0 +1,6 @@
+pr: 124722
+summary: Add panama implementations of byte-bit and float-bit script operations
+area: Vector Search
+type: enhancement
+issues:
+ - 117096

+ 10 - 2
libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java

@@ -45,13 +45,17 @@ final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
     }
 
     public static int ipByteBitImpl(byte[] q, byte[] d) {
+        return ipByteBitImpl(q, d, 0);
+    }
+
+    public static int ipByteBitImpl(byte[] q, byte[] d, int start) {
         assert q.length == d.length * Byte.SIZE;
         int acc0 = 0;
         int acc1 = 0;
         int acc2 = 0;
         int acc3 = 0;
         // now combine the two vectors, summing the byte dimensions where the bit in d is `1`
-        for (int i = 0; i < d.length; i++) {
+        for (int i = start; i < d.length; i++) {
             byte mask = d[i];
             // Make sure its just 1 or 0
 
@@ -69,13 +73,17 @@ final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
     }
 
     public static float ipFloatBitImpl(float[] q, byte[] d) {
+        return ipFloatBitImpl(q, d, 0);
+    }
+
+    static float ipFloatBitImpl(float[] q, byte[] d, int start) {
         assert q.length == d.length * Byte.SIZE;
         float acc0 = 0;
         float acc1 = 0;
         float acc2 = 0;
         float acc3 = 0;
         // now combine the two vectors, summing the byte dimensions where the bit in d is `1`
-        for (int i = 0; i < d.length; i++) {
+        for (int i = start; i < d.length; i++) {
             byte mask = d[i];
             acc0 = fma(q[i * Byte.SIZE + 0], (mask >> 7) & 1, acc0);
             acc1 = fma(q[i * Byte.SIZE + 1], (mask >> 6) & 1, acc1);

+ 251 - 1
libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java

@@ -13,10 +13,12 @@ import jdk.incubator.vector.ByteVector;
 import jdk.incubator.vector.FloatVector;
 import jdk.incubator.vector.IntVector;
 import jdk.incubator.vector.LongVector;
+import jdk.incubator.vector.VectorMask;
 import jdk.incubator.vector.VectorOperators;
 import jdk.incubator.vector.VectorShape;
 import jdk.incubator.vector.VectorSpecies;
 
+import org.apache.lucene.util.BitUtil;
 import org.apache.lucene.util.Constants;
 
 public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
@@ -51,11 +53,25 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
 
     @Override
     public int ipByteBit(byte[] q, byte[] d) {
+        if (d.length >= 16 && HAS_FAST_INTEGER_VECTORS) {
+            if (VECTOR_BITSIZE >= 512) {
+                return ipByteBit512(q, d);
+            } else if (VECTOR_BITSIZE == 256) {
+                return ipByteBit256(q, d);
+            }
+        }
         return DefaultESVectorUtilSupport.ipByteBitImpl(q, d);
     }
 
     @Override
     public float ipFloatBit(float[] q, byte[] d) {
+        if (q.length >= 16) {
+            if (VECTOR_BITSIZE >= 512) {
+                return ipFloatBit512(q, d);
+            } else if (VECTOR_BITSIZE == 256) {
+                return ipFloatBit256(q, d);
+            }
+        }
         return DefaultESVectorUtilSupport.ipFloatBitImpl(q, d);
     }
 
@@ -170,6 +186,240 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
         return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
     }
 
+    private static final VectorSpecies<Integer> INT_SPECIES_512 = IntVector.SPECIES_512;
+    private static final VectorSpecies<Byte> BYTE_SPECIES_FOR_INT_512 = VectorSpecies.of(
+        byte.class,
+        VectorShape.forBitSize(INT_SPECIES_512.vectorBitSize() / Integer.BYTES)
+    );
+    private static final VectorSpecies<Integer> INT_SPECIES_256 = IntVector.SPECIES_256;
+    private static final VectorSpecies<Byte> BYTE_SPECIES_FOR_INT_256 = VectorSpecies.of(
+        byte.class,
+        VectorShape.forBitSize(INT_SPECIES_256.vectorBitSize() / Integer.BYTES)
+    );
+
+    private static int limit(int length, int sectionSize) {
+        return length - (length % sectionSize);
+    }
+
+    static int ipByteBit512(byte[] q, byte[] d) {
+        assert q.length == d.length * Byte.SIZE;
+        int i = 0;
+        int sum = 0;
+
+        int sectionLength = INT_SPECIES_512.length() * 4;
+        if (q.length >= sectionLength) {
+            IntVector acc0 = IntVector.zero(INT_SPECIES_512);
+            IntVector acc1 = IntVector.zero(INT_SPECIES_512);
+            IntVector acc2 = IntVector.zero(INT_SPECIES_512);
+            IntVector acc3 = IntVector.zero(INT_SPECIES_512);
+            int limit = limit(q.length, sectionLength);
+            for (; i < limit; i += sectionLength) {
+                var vals0 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_512, q, i).castShape(INT_SPECIES_512, 0);
+                var vals1 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_512, q, i + INT_SPECIES_512.length()).castShape(INT_SPECIES_512, 0);
+                var vals2 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_512, q, i + INT_SPECIES_512.length() * 2)
+                    .castShape(INT_SPECIES_512, 0);
+                var vals3 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_512, q, i + INT_SPECIES_512.length() * 3)
+                    .castShape(INT_SPECIES_512, 0);
+
+                long maskBits = Long.reverse((long) BitUtil.VH_BE_LONG.get(d, i / 8));
+                var mask0 = VectorMask.fromLong(INT_SPECIES_512, maskBits);
+                var mask1 = VectorMask.fromLong(INT_SPECIES_512, maskBits >> 16);
+                var mask2 = VectorMask.fromLong(INT_SPECIES_512, maskBits >> 32);
+                var mask3 = VectorMask.fromLong(INT_SPECIES_512, maskBits >> 48);
+
+                acc0 = acc0.add(vals0, mask0);
+                acc1 = acc1.add(vals1, mask1);
+                acc2 = acc2.add(vals2, mask2);
+                acc3 = acc3.add(vals3, mask3);
+            }
+            sum += acc0.reduceLanes(VectorOperators.ADD) + acc1.reduceLanes(VectorOperators.ADD) + acc2.reduceLanes(VectorOperators.ADD)
+                + acc3.reduceLanes(VectorOperators.ADD);
+        }
+
+        sectionLength = INT_SPECIES_256.length();
+        if (q.length - i >= sectionLength) {
+            IntVector acc = IntVector.zero(INT_SPECIES_256);
+            int limit = limit(q.length, sectionLength);
+            for (; i < limit; i += sectionLength) {
+                var vals = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i).castShape(INT_SPECIES_256, 0);
+
+                long maskBits = Integer.reverse(d[i / 8]) >> 24;
+                var mask = VectorMask.fromLong(INT_SPECIES_256, maskBits);
+
+                acc = acc.add(vals, mask);
+            }
+            sum += acc.reduceLanes(VectorOperators.ADD);
+        }
+
+        // that should have got them all (q.length is a multiple of 8, which fits in a 256-bit vector)
+        assert i == q.length;
+        return sum;
+    }
+
+    static int ipByteBit256(byte[] q, byte[] d) {
+        assert q.length == d.length * Byte.SIZE;
+        int i = 0;
+        int sum = 0;
+
+        int sectionLength = INT_SPECIES_256.length() * 4;
+        if (q.length >= sectionLength) {
+            IntVector acc0 = IntVector.zero(INT_SPECIES_256);
+            IntVector acc1 = IntVector.zero(INT_SPECIES_256);
+            IntVector acc2 = IntVector.zero(INT_SPECIES_256);
+            IntVector acc3 = IntVector.zero(INT_SPECIES_256);
+            int limit = limit(q.length, sectionLength);
+            for (; i < limit; i += sectionLength) {
+                var vals0 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i).castShape(INT_SPECIES_256, 0);
+                var vals1 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i + INT_SPECIES_256.length()).castShape(INT_SPECIES_256, 0);
+                var vals2 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i + INT_SPECIES_256.length() * 2)
+                    .castShape(INT_SPECIES_256, 0);
+                var vals3 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i + INT_SPECIES_256.length() * 3)
+                    .castShape(INT_SPECIES_256, 0);
+
+                long maskBits = Integer.reverse((int) BitUtil.VH_BE_INT.get(d, i / 8));
+                var mask0 = VectorMask.fromLong(INT_SPECIES_256, maskBits);
+                var mask1 = VectorMask.fromLong(INT_SPECIES_256, maskBits >> 8);
+                var mask2 = VectorMask.fromLong(INT_SPECIES_256, maskBits >> 16);
+                var mask3 = VectorMask.fromLong(INT_SPECIES_256, maskBits >> 24);
+
+                acc0 = acc0.add(vals0, mask0);
+                acc1 = acc1.add(vals1, mask1);
+                acc2 = acc2.add(vals2, mask2);
+                acc3 = acc3.add(vals3, mask3);
+            }
+            sum += acc0.reduceLanes(VectorOperators.ADD) + acc1.reduceLanes(VectorOperators.ADD) + acc2.reduceLanes(VectorOperators.ADD)
+                + acc3.reduceLanes(VectorOperators.ADD);
+        }
+
+        sectionLength = INT_SPECIES_256.length();
+        if (q.length - i >= sectionLength) {
+            IntVector acc = IntVector.zero(INT_SPECIES_256);
+            int limit = limit(q.length, sectionLength);
+            for (; i < limit; i += sectionLength) {
+                var vals = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i).castShape(INT_SPECIES_256, 0);
+
+                long maskBits = Integer.reverse(d[i / 8]) >> 24;
+                var mask = VectorMask.fromLong(INT_SPECIES_256, maskBits);
+
+                acc = acc.add(vals, mask);
+            }
+            sum += acc.reduceLanes(VectorOperators.ADD);
+        }
+
+        // that should have got them all (q.length is a multiple of 8, which fits in a 256-bit vector)
+        assert i == q.length;
+        return sum;
+    }
+
+    private static final VectorSpecies<Float> FLOAT_SPECIES_512 = FloatVector.SPECIES_512;
+    private static final VectorSpecies<Float> FLOAT_SPECIES_256 = FloatVector.SPECIES_256;
+
+    static float ipFloatBit512(float[] q, byte[] d) {
+        assert q.length == d.length * Byte.SIZE;
+        int i = 0;
+        float sum = 0;
+
+        int sectionLength = FLOAT_SPECIES_512.length() * 4;
+        if (q.length >= sectionLength) {
+            FloatVector acc0 = FloatVector.zero(FLOAT_SPECIES_512);
+            FloatVector acc1 = FloatVector.zero(FLOAT_SPECIES_512);
+            FloatVector acc2 = FloatVector.zero(FLOAT_SPECIES_512);
+            FloatVector acc3 = FloatVector.zero(FLOAT_SPECIES_512);
+            int limit = limit(q.length, sectionLength);
+            for (; i < limit; i += sectionLength) {
+                var floats0 = FloatVector.fromArray(FLOAT_SPECIES_512, q, i);
+                var floats1 = FloatVector.fromArray(FLOAT_SPECIES_512, q, i + FLOAT_SPECIES_512.length());
+                var floats2 = FloatVector.fromArray(FLOAT_SPECIES_512, q, i + FLOAT_SPECIES_512.length() * 2);
+                var floats3 = FloatVector.fromArray(FLOAT_SPECIES_512, q, i + FLOAT_SPECIES_512.length() * 3);
+
+                long maskBits = Long.reverse((long) BitUtil.VH_BE_LONG.get(d, i / 8));
+                var mask0 = VectorMask.fromLong(FLOAT_SPECIES_512, maskBits);
+                var mask1 = VectorMask.fromLong(FLOAT_SPECIES_512, maskBits >> 16);
+                var mask2 = VectorMask.fromLong(FLOAT_SPECIES_512, maskBits >> 32);
+                var mask3 = VectorMask.fromLong(FLOAT_SPECIES_512, maskBits >> 48);
+
+                acc0 = acc0.add(floats0, mask0);
+                acc1 = acc1.add(floats1, mask1);
+                acc2 = acc2.add(floats2, mask2);
+                acc3 = acc3.add(floats3, mask3);
+            }
+            sum += acc0.reduceLanes(VectorOperators.ADD) + acc1.reduceLanes(VectorOperators.ADD) + acc2.reduceLanes(VectorOperators.ADD)
+                + acc3.reduceLanes(VectorOperators.ADD);
+        }
+
+        sectionLength = FLOAT_SPECIES_256.length();
+        if (q.length - i >= sectionLength) {
+            FloatVector acc = FloatVector.zero(FLOAT_SPECIES_256);
+            int limit = limit(q.length, sectionLength);
+            for (; i < limit; i += sectionLength) {
+                var floats = FloatVector.fromArray(FLOAT_SPECIES_256, q, i);
+
+                long maskBits = Integer.reverse(d[i / 8]) >> 24;
+                var mask = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits);
+
+                acc = acc.add(floats, mask);
+            }
+            sum += acc.reduceLanes(VectorOperators.ADD);
+        }
+
+        // that should have got them all (q.length is a multiple of 8, which fits in a 256-bit vector)
+        assert i == q.length;
+        return sum;
+    }
+
+    static float ipFloatBit256(float[] q, byte[] d) {
+        assert q.length == d.length * Byte.SIZE;
+        int i = 0;
+        float sum = 0;
+
+        int sectionLength = FLOAT_SPECIES_256.length() * 4;
+        if (q.length >= sectionLength) {
+            FloatVector acc0 = FloatVector.zero(FLOAT_SPECIES_256);
+            FloatVector acc1 = FloatVector.zero(FLOAT_SPECIES_256);
+            FloatVector acc2 = FloatVector.zero(FLOAT_SPECIES_256);
+            FloatVector acc3 = FloatVector.zero(FLOAT_SPECIES_256);
+            int limit = limit(q.length, sectionLength);
+            for (; i < limit; i += sectionLength) {
+                var floats0 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i);
+                var floats1 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i + FLOAT_SPECIES_256.length());
+                var floats2 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i + FLOAT_SPECIES_256.length() * 2);
+                var floats3 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i + FLOAT_SPECIES_256.length() * 3);
+
+                long maskBits = Integer.reverse((int) BitUtil.VH_BE_INT.get(d, i / 8));
+                var mask0 = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits);
+                var mask1 = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits >> 8);
+                var mask2 = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits >> 16);
+                var mask3 = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits >> 24);
+
+                acc0 = acc0.add(floats0, mask0);
+                acc1 = acc1.add(floats1, mask1);
+                acc2 = acc2.add(floats2, mask2);
+                acc3 = acc3.add(floats3, mask3);
+            }
+            sum += acc0.reduceLanes(VectorOperators.ADD) + acc1.reduceLanes(VectorOperators.ADD) + acc2.reduceLanes(VectorOperators.ADD)
+                + acc3.reduceLanes(VectorOperators.ADD);
+        }
+
+        sectionLength = FLOAT_SPECIES_256.length();
+        if (q.length - i >= sectionLength) {
+            FloatVector acc = FloatVector.zero(FLOAT_SPECIES_256);
+            int limit = limit(q.length, sectionLength);
+            for (; i < limit; i += sectionLength) {
+                var floats = FloatVector.fromArray(FLOAT_SPECIES_256, q, i);
+
+                long maskBits = Integer.reverse(d[i / 8]) >> 24;
+                var mask = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits);
+
+                acc = acc.add(floats, mask);
+            }
+            sum += acc.reduceLanes(VectorOperators.ADD);
+        }
+
+        // that should have got them all (q.length is a multiple of 8, which fits in a 256-bit vector)
+        assert i == q.length;
+        return sum;
+    }
+
     private static final VectorSpecies<Float> PREFERRED_FLOAT_SPECIES = FloatVector.SPECIES_PREFERRED;
     private static final VectorSpecies<Byte> BYTE_SPECIES_FOR_PREFFERED_FLOATS;
 
@@ -177,7 +427,7 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
         VectorSpecies<Byte> byteForFloat;
         try {
             // calculate vector size to convert from single bytes to 4-byte floats
-            byteForFloat = VectorSpecies.of(byte.class, VectorShape.forBitSize(PREFERRED_FLOAT_SPECIES.vectorBitSize() / Integer.BYTES));
+            byteForFloat = VectorSpecies.of(byte.class, VectorShape.forBitSize(PREFERRED_FLOAT_SPECIES.vectorBitSize() / Float.BYTES));
         } catch (IllegalArgumentException e) {
             // can't get a byte vector size small enough, just use default impl
             byteForFloat = null;

+ 31 - 16
libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java

@@ -13,7 +13,6 @@ import org.elasticsearch.simdvec.internal.vectorization.BaseVectorizationTests;
 import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
 
 import java.util.Arrays;
-import java.util.function.ToDoubleBiFunction;
 import java.util.function.ToLongBiFunction;
 
 import static org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport.B_QUERY;
@@ -25,30 +24,44 @@ public class ESVectorUtilTests extends BaseVectorizationTests {
     static final ESVectorizationProvider defOrPanamaProvider = BaseVectorizationTests.maybePanamaProvider();
 
     public void testIpByteBit() {
-        byte[] q = new byte[16];
-        byte[] d = new byte[] { (byte) Integer.parseInt("01100010", 2), (byte) Integer.parseInt("10100111", 2) };
+        byte[] d = new byte[random().nextInt(128)];
+        byte[] q = new byte[d.length * 8];
+        random().nextBytes(d);
         random().nextBytes(q);
-        int expected = q[1] + q[2] + q[6] + q[8] + q[10] + q[13] + q[14] + q[15];
-        assertEquals(expected, ESVectorUtil.ipByteBit(q, d));
+
+        int sum = 0;
+        for (int i = 0; i < q.length; i++) {
+            if (((d[i / 8] << (i % 8)) & 0x80) == 0x80) {
+                sum += q[i];
+            }
+        }
+
+        assertEquals(sum, ESVectorUtil.ipByteBit(q, d));
+        assertEquals(sum, defaultedProvider.getVectorUtilSupport().ipByteBit(q, d));
+        assertEquals(sum, defOrPanamaProvider.getVectorUtilSupport().ipByteBit(q, d));
     }
 
     public void testIpFloatBit() {
-        float[] q = new float[16];
-        byte[] d = new byte[] { (byte) Integer.parseInt("01100010", 2), (byte) Integer.parseInt("10100111", 2) };
+        byte[] d = new byte[random().nextInt(128)];
+        float[] q = new float[d.length * 8];
+        random().nextBytes(d);
+
+        float sum = 0;
         for (int i = 0; i < q.length; i++) {
             q[i] = random().nextFloat();
+            if (((d[i / 8] << (i % 8)) & 0x80) == 0x80) {
+                sum += q[i];
+            }
         }
-        float expected = q[1] + q[2] + q[6] + q[8] + q[10] + q[13] + q[14] + q[15];
-        assertEquals(expected, ESVectorUtil.ipFloatBit(q, d), 1e-6);
-    }
 
-    public void testIpFloatByte() {
-        testIpFloatByteImpl(ESVectorUtil::ipFloatByte);
-        testIpFloatByteImpl(defaultedProvider.getVectorUtilSupport()::ipFloatByte);
-        testIpFloatByteImpl(defOrPanamaProvider.getVectorUtilSupport()::ipFloatByte);
+        double delta = 1e-5 * q.length;
+
+        assertEquals(sum, ESVectorUtil.ipFloatBit(q, d), delta);
+        assertEquals(sum, defaultedProvider.getVectorUtilSupport().ipFloatBit(q, d), delta);
+        assertEquals(sum, defOrPanamaProvider.getVectorUtilSupport().ipFloatBit(q, d), delta);
     }
 
-    private void testIpFloatByteImpl(ToDoubleBiFunction<float[], byte[]> impl) {
+    public void testIpFloatByte() {
         int vectorSize = randomIntBetween(1, 1024);
         // scale the delta according to the vector size
         double delta = 1e-5 * vectorSize;
@@ -64,7 +77,9 @@ public class ESVectorUtilTests extends BaseVectorizationTests {
         for (int i = 0; i < q.length; i++) {
             expected += q[i] * d[i];
         }
-        assertThat(impl.applyAsDouble(q, d), closeTo(expected, delta));
+        assertThat((double) ESVectorUtil.ipFloatByte(q, d), closeTo(expected, delta));
+        assertThat((double) defaultedProvider.getVectorUtilSupport().ipFloatByte(q, d), closeTo(expected, delta));
+        assertThat((double) defOrPanamaProvider.getVectorUtilSupport().ipFloatByte(q, d), closeTo(expected, delta));
     }
 
     public void testBitAndCount() {