Browse Source

Speed up bit compared with floats or bytes script operations (#117199) (#117841)

Instead of doing an "if" statement, which doesn't lend itself to
vectorization, I switched to expand to the bits and multiply the 1s and
0s.

This led to a marginal speed improvement on ARM.

I expect that Panama vector could be used here to be even faster, but I
didn't want to spend anymore time on this for the time being.

```
Benchmark                                              (dims)   Mode  Cnt  Score   Error   Units
IpBitVectorScorerBenchmark.dotProductByteIfStatement      768  thrpt    5  2.952 ± 0.026  ops/us
IpBitVectorScorerBenchmark.dotProductByteUnwrap           768  thrpt    5  4.017 ± 0.068  ops/us
IpBitVectorScorerBenchmark.dotProductFloatIfStatement     768  thrpt    5  2.987 ± 0.124  ops/us
IpBitVectorScorerBenchmark.dotProductFloatUnwrap          768  thrpt    5  4.726 ± 0.136  ops/us
```

Benchmark I used.
https://gist.github.com/benwtrent/b0edb3975d2f03356c1a5ea84c72abc9
Benjamin Trent 10 months ago
parent
commit
ff18d1b6ce

+ 5 - 0
docs/changelog/117199.yaml

@@ -0,0 +1,5 @@
+pr: 117199
+summary: Speed up bit compared with floats or bytes script operations
+area: Vector Search
+type: enhancement
+issues: []

+ 2 - 21
libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java

@@ -61,17 +61,7 @@ public class ESVectorUtil {
         if (q.length != d.length * Byte.SIZE) {
             throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + Byte.SIZE + " x " + d.length);
         }
-        int result = 0;
-        // now combine the two vectors, summing the byte dimensions where the bit in d is `1`
-        for (int i = 0; i < d.length; i++) {
-            byte mask = d[i];
-            for (int j = Byte.SIZE - 1; j >= 0; j--) {
-                if ((mask & (1 << j)) != 0) {
-                    result += q[i * Byte.SIZE + Byte.SIZE - 1 - j];
-                }
-            }
-        }
-        return result;
+        return IMPL.ipByteBit(q, d);
     }
 
     /**
@@ -87,16 +77,7 @@ public class ESVectorUtil {
         if (q.length != d.length * Byte.SIZE) {
             throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + Byte.SIZE + " x " + d.length);
         }
-        float result = 0;
-        for (int i = 0; i < d.length; i++) {
-            byte mask = d[i];
-            for (int j = Byte.SIZE - 1; j >= 0; j--) {
-                if ((mask & (1 << j)) != 0) {
-                    result += q[i * Byte.SIZE + Byte.SIZE - 1 - j];
-                }
-            }
-        }
-        return result;
+        return IMPL.ipFloatBit(q, d);
     }
 
     /**

+ 65 - 0
libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java

@@ -10,9 +10,18 @@
 package org.elasticsearch.simdvec.internal.vectorization;
 
 import org.apache.lucene.util.BitUtil;
+import org.apache.lucene.util.Constants;
 
 final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
 
+    private static float fma(float a, float b, float c) {
+        if (Constants.HAS_FAST_SCALAR_FMA) {
+            return Math.fma(a, b, c);
+        } else {
+            return a * b + c;
+        }
+    }
+
     DefaultESVectorUtilSupport() {}
 
     @Override
@@ -20,6 +29,62 @@ final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
         return ipByteBinByteImpl(q, d);
     }
 
+    @Override
+    public int ipByteBit(byte[] q, byte[] d) {
+        return ipByteBitImpl(q, d);
+    }
+
+    @Override
+    public float ipFloatBit(float[] q, byte[] d) {
+        return ipFloatBitImpl(q, d);
+    }
+
+    public static int ipByteBitImpl(byte[] q, byte[] d) {
+        assert q.length == d.length * Byte.SIZE;
+        int acc0 = 0;
+        int acc1 = 0;
+        int acc2 = 0;
+        int acc3 = 0;
+        // now combine the two vectors, summing the byte dimensions where the bit in d is `1`
+        for (int i = 0; i < d.length; i++) {
+            byte mask = d[i];
+            // Make sure its just 1 or 0
+
+            acc0 += q[i * Byte.SIZE + 0] * ((mask >> 7) & 1);
+            acc1 += q[i * Byte.SIZE + 1] * ((mask >> 6) & 1);
+            acc2 += q[i * Byte.SIZE + 2] * ((mask >> 5) & 1);
+            acc3 += q[i * Byte.SIZE + 3] * ((mask >> 4) & 1);
+
+            acc0 += q[i * Byte.SIZE + 4] * ((mask >> 3) & 1);
+            acc1 += q[i * Byte.SIZE + 5] * ((mask >> 2) & 1);
+            acc2 += q[i * Byte.SIZE + 6] * ((mask >> 1) & 1);
+            acc3 += q[i * Byte.SIZE + 7] * ((mask >> 0) & 1);
+        }
+        return acc0 + acc1 + acc2 + acc3;
+    }
+
+    public static float ipFloatBitImpl(float[] q, byte[] d) {
+        assert q.length == d.length * Byte.SIZE;
+        float acc0 = 0;
+        float acc1 = 0;
+        float acc2 = 0;
+        float acc3 = 0;
+        // now combine the two vectors, summing the byte dimensions where the bit in d is `1`
+        for (int i = 0; i < d.length; i++) {
+            byte mask = d[i];
+            acc0 = fma(q[i * Byte.SIZE + 0], (mask >> 7) & 1, acc0);
+            acc1 = fma(q[i * Byte.SIZE + 1], (mask >> 6) & 1, acc1);
+            acc2 = fma(q[i * Byte.SIZE + 2], (mask >> 5) & 1, acc2);
+            acc3 = fma(q[i * Byte.SIZE + 3], (mask >> 4) & 1, acc3);
+
+            acc0 = fma(q[i * Byte.SIZE + 4], (mask >> 3) & 1, acc0);
+            acc1 = fma(q[i * Byte.SIZE + 5], (mask >> 2) & 1, acc1);
+            acc2 = fma(q[i * Byte.SIZE + 6], (mask >> 1) & 1, acc2);
+            acc3 = fma(q[i * Byte.SIZE + 7], (mask >> 0) & 1, acc3);
+        }
+        return acc0 + acc1 + acc2 + acc3;
+    }
+
     public static long ipByteBinByteImpl(byte[] q, byte[] d) {
         long ret = 0;
         int size = d.length;

+ 4 - 0
libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java

@@ -14,4 +14,8 @@ public interface ESVectorUtilSupport {
     short B_QUERY = 4;
 
     long ipByteBinByte(byte[] q, byte[] d);
+
+    int ipByteBit(byte[] q, byte[] d);
+
+    float ipFloatBit(float[] q, byte[] d);
 }

+ 10 - 0
libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java

@@ -48,6 +48,16 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
         return DefaultESVectorUtilSupport.ipByteBinByteImpl(q, d);
     }
 
+    @Override
+    public int ipByteBit(byte[] q, byte[] d) {
+        return DefaultESVectorUtilSupport.ipByteBitImpl(q, d);
+    }
+
+    @Override
+    public float ipFloatBit(float[] q, byte[] d) {
+        return DefaultESVectorUtilSupport.ipFloatBitImpl(q, d);
+    }
+
     private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
     private static final VectorSpecies<Byte> BYTE_SPECIES_256 = ByteVector.SPECIES_256;