Browse Source

Improve halfbyte transposition performance, marginally improving bbq performance (#117350) (#118293)

The transposition of the bits in half-byte queries for BBQ is pretty
convoluted and slow. This commit greatly simplifies & improves
performance for this small part of bbq queries and indexing.

Here are the results of a small JMH benchmark for this particular
function.

```
TransposeBinBenchmark.transposeBinNew     1024  thrpt    5  857.779 ± 44.031  ops/ms
TransposeBinBenchmark.transposeBinOrig    1024  thrpt    5   94.950 ±  2.898  ops/ms
```

While this is a huge improvement for this small function, the impact at
query and index time is only marginal. But, the code simplification
itself is enough to warrant this change in my opinion.

(cherry picked from commit e90eb7ab0df06239a69a1945ca6ef5effc065433)
Benjamin Trent 10 months ago
parent
commit
67332f812f

+ 5 - 0
docs/changelog/117350.yaml

@@ -0,0 +1,5 @@
+pr: 117350
+summary: "Improve halfbyte transposition performance, marginally improving bbq performance"
+area: Vector Search
+type: enhancement
+issues: []

+ 25 - 43
server/src/main/java/org/elasticsearch/index/codec/vectors/BQSpaceUtils.java

@@ -23,56 +23,38 @@ package org.elasticsearch.index.codec.vectors;
 public class BQSpaceUtils {
 
     public static final short B_QUERY = 4;
-    // the first four bits masked
-    private static final int B_QUERY_MASK = 15;
 
     /**
      * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10
+     * Transpose the query vector into a byte array allowing for efficient bitwise operations with the
+     * index bit vectors. The idea here is to organize the query vector bits such that the first bit
+     * of every dimension is in the first set dimensions bits, or (dimensions/8) bytes. The second,
+     * third, and fourth bits are in the second, third, and fourth set of dimensions bits,
+     * respectively. This allows for direct bitwise comparisons with the stored index vectors through
+     * summing the bitwise results with the relative required bit shifts.
+     *
      * @param q the query vector, assumed to be half-byte quantized with values between 0 and 15
-     * @param dimensions the number of dimensions in the query vector
      * @param quantQueryByte the byte array to store the transposed query vector
      */
-    public static void transposeBin(byte[] q, int dimensions, byte[] quantQueryByte) {
-        // TODO: rewrite this in Panama Vector API
-        int qOffset = 0;
-        final byte[] v1 = new byte[4];
-        final byte[] v = new byte[32];
-        for (int i = 0; i < dimensions; i += 32) {
-            // for every four bytes we shift left (with remainder across those bytes)
-            for (int j = 0; j < v.length; j += 4) {
-                v[j] = (byte) (q[qOffset + j] << B_QUERY | ((q[qOffset + j] >>> B_QUERY) & B_QUERY_MASK));
-                v[j + 1] = (byte) (q[qOffset + j + 1] << B_QUERY | ((q[qOffset + j + 1] >>> B_QUERY) & B_QUERY_MASK));
-                v[j + 2] = (byte) (q[qOffset + j + 2] << B_QUERY | ((q[qOffset + j + 2] >>> B_QUERY) & B_QUERY_MASK));
-                v[j + 3] = (byte) (q[qOffset + j + 3] << B_QUERY | ((q[qOffset + j + 3] >>> B_QUERY) & B_QUERY_MASK));
-            }
-            for (int j = 0; j < B_QUERY; j++) {
-                moveMaskEpi8Byte(v, v1);
-                for (int k = 0; k < 4; k++) {
-                    quantQueryByte[(B_QUERY - j - 1) * (dimensions / 8) + i / 8 + k] = v1[k];
-                    v1[k] = 0;
-                }
-                for (int k = 0; k < v.length; k += 4) {
-                    v[k] = (byte) (v[k] + v[k]);
-                    v[k + 1] = (byte) (v[k + 1] + v[k + 1]);
-                    v[k + 2] = (byte) (v[k + 2] + v[k + 2]);
-                    v[k + 3] = (byte) (v[k + 3] + v[k + 3]);
-                }
-            }
-            qOffset += 32;
-        }
-    }
-
-    private static void moveMaskEpi8Byte(byte[] v, byte[] v1b) {
-        int m = 0;
-        for (int k = 0; k < v.length; k++) {
-            if ((v[k] & 0b10000000) == 0b10000000) {
-                v1b[m] |= 0b00000001;
-            }
-            if (k % 8 == 7) {
-                m++;
-            } else {
-                v1b[m] <<= 1;
+    public static void transposeHalfByte(byte[] q, byte[] quantQueryByte) {
+        for (int i = 0; i < q.length;) {
+            assert q[i] >= 0 && q[i] <= 15;
+            int lowerByte = 0;
+            int lowerMiddleByte = 0;
+            int upperMiddleByte = 0;
+            int upperByte = 0;
+            for (int j = 7; j >= 0 && i < q.length; j--) {
+                lowerByte |= (q[i] & 1) << j;
+                lowerMiddleByte |= ((q[i] >> 1) & 1) << j;
+                upperMiddleByte |= ((q[i] >> 2) & 1) << j;
+                upperByte |= ((q[i] >> 3) & 1) << j;
+                i++;
             }
+            int index = ((i + 7) / 8) - 1;
+            quantQueryByte[index] = (byte) lowerByte;
+            quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte;
+            quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte;
+            quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
         }
     }
 }

+ 2 - 6
server/src/main/java/org/elasticsearch/index/codec/vectors/es816/BinaryQuantizer.java

@@ -225,9 +225,7 @@ public class BinaryQuantizer {
 
         // q¯ = Δ · q¯𝑢 + 𝑣𝑙 · 1𝐷
         // q¯ is an approximation of q′ (scalar quantized approximation)
-        // FIXME: vectors need to be padded but that's expensive; update transponseBin to deal
-        byteQuery = BQVectorUtils.pad(byteQuery, discretizedDimensions);
-        BQSpaceUtils.transposeBin(byteQuery, discretizedDimensions, queryDestination);
+        BQSpaceUtils.transposeHalfByte(byteQuery, queryDestination);
         QueryFactors factors = new QueryFactors(quantResult.quantizedSum, distToC, lower, width, normVmC, vDotC);
         final float[] indexCorrections;
         if (similarityFunction == EUCLIDEAN) {
@@ -368,9 +366,7 @@ public class BinaryQuantizer {
 
         // q¯ = Δ · q¯𝑢 + 𝑣𝑙 · 1𝐷
         // q¯ is an approximation of q′ (scalar quantized approximation)
-        // FIXME: vectors need to be padded but that's expensive; update transponseBin to deal
-        byteQuery = BQVectorUtils.pad(byteQuery, discretizedDimensions);
-        BQSpaceUtils.transposeBin(byteQuery, discretizedDimensions, destination);
+        BQSpaceUtils.transposeHalfByte(byteQuery, destination);
 
         QueryFactors factors;
         if (similarityFunction != EUCLIDEAN) {