Преглед изворни кода

Vectorize BQSpaceUtils#transposeHalfByte (#132935)

Ignacio Vera пре 2 месеци
родитељ
комит
8c01b6706b

+ 9 - 0
benchmarks/src/main/java/org/elasticsearch/benchmark/vector/TransposeHalfByteBenchmark.java

@@ -83,4 +83,13 @@ public class TransposeHalfByteBenchmark {
             bh.consume(packed);
         }
     }
+
+    @Benchmark
+    @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
+    public void transposeHalfBytePanama(Blackhole bh) {
+        for (int i = 0; i < numVectors; i++) {
+            BQSpaceUtils.transposeHalfByte(qVectors[i], packed);
+            bh.consume(packed);
+        }
+    }
 }

+ 18 - 0
libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java

@@ -381,4 +381,22 @@ public class ESVectorUtil {
         }
         IMPL.packAsBinary(vector, packed);
     }
+
+    /**
+     * Transposes a half-byte (4-bit) quantized query vector into four bit planes.
+     * <p>
+     * The least significant bit of every dimension is packed into the first quarter of
+     * {@code quantQueryByte}; the second, third and fourth bits go into the second, third and
+     * fourth quarters respectively. Within each output byte the first dimension of a group of
+     * eight occupies the most significant bit. This layout allows direct bitwise comparisons with
+     * the stored index vectors by summing the bitwise results with the relative required bit shifts.
+     *
+     * @param q the query vector, assumed to be half-byte quantized with values between 0 and 15
+     * @param quantQueryByte the byte array to store the transposed query vector; it must hold at
+     *                       least {@code 4 * ceil(q.length / 8)} bytes
+     * @throws IllegalArgumentException if {@code quantQueryByte} is too small to hold the four bit planes
+     */
+    public static void transposeHalfByte(int[] q, byte[] quantQueryByte) {
+        // Every plane needs one byte per group of up to 8 dimensions. A trailing partial group still
+        // occupies a whole byte, so round up; otherwise neighbouring planes would overwrite each other.
+        final int bytesPerPlane = (q.length + Byte.SIZE - 1) / Byte.SIZE;
+        final int requiredBytes = 4 * bytesPerPlane;
+        if (quantQueryByte.length < requiredBytes) {
+            throw new IllegalArgumentException(
+                "packed array is too small: " + quantQueryByte.length * Byte.SIZE + " < " + requiredBytes * Byte.SIZE
+            );
+        }
+        IMPL.transposeHalfByte(q, quantQueryByte);
+    }
 }

+ 50 - 0
libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java

@@ -353,4 +353,54 @@ final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
         }
         packed[index] = result;
     }
+
+    /** Scalar variant: forwards straight to {@link #transposeHalfByteImpl(int[], byte[])}. */
+    @Override
+    public void transposeHalfByte(int[] q, byte[] quantQueryByte) {
+        DefaultESVectorUtilSupport.transposeHalfByteImpl(q, quantQueryByte);
+    }
+
+    /**
+     * Scalar bit-plane transpose of a half-byte quantized vector.
+     * <p>
+     * Bit {@code k} (0..3) of every value in {@code q} is packed, eight values per byte with the
+     * first value in the most significant bit, into the region of {@code quantQueryByte} that
+     * starts at {@code k * quantQueryByte.length / 4}. A trailing group of fewer than eight values
+     * fills only the high bits of its byte; the remaining low bits are written as zero.
+     *
+     * @param q values to transpose; full groups of eight are asserted to lie in {@code [0, 15]}
+     * @param quantQueryByte destination for the four bit planes
+     */
+    public static void transposeHalfByteImpl(int[] q, byte[] quantQueryByte) {
+        final int dims = q.length;
+        // Start offsets of the second, third and fourth bit planes in the output.
+        final int plane1 = quantQueryByte.length / 4;
+        final int plane2 = quantQueryByte.length / 2;
+        final int plane3 = 3 * quantQueryByte.length / 4;
+        final int fullGroupsEnd = dims - 7;
+
+        int pos = 0;
+        int out = 0;
+        // Full groups of eight values: one output byte per plane.
+        while (pos < fullGroupsEnd) {
+            int bit0 = 0;
+            int bit1 = 0;
+            int bit2 = 0;
+            int bit3 = 0;
+            for (int lane = 0; lane < 8; lane++) {
+                final int value = q[pos + lane];
+                assert value >= 0 && value <= 15;
+                final int shift = 7 - lane;
+                bit0 |= (value & 1) << shift;
+                bit1 |= ((value >> 1) & 1) << shift;
+                bit2 |= ((value >> 2) & 1) << shift;
+                bit3 |= ((value >> 3) & 1) << shift;
+            }
+            quantQueryByte[out] = (byte) bit0;
+            quantQueryByte[out + plane1] = (byte) bit1;
+            quantQueryByte[out + plane2] = (byte) bit2;
+            quantQueryByte[out + plane3] = (byte) bit3;
+            pos += 8;
+            out++;
+        }
+        if (pos == dims) {
+            // The length was a multiple of eight, nothing left over.
+            return;
+        }
+
+        // Remaining (fewer than eight) values fill the high bits of one last byte per plane.
+        int bit0 = 0;
+        int bit1 = 0;
+        int bit2 = 0;
+        int bit3 = 0;
+        int shift = 7;
+        while (pos < dims) {
+            final int value = q[pos];
+            bit0 |= (value & 1) << shift;
+            bit1 |= ((value >> 1) & 1) << shift;
+            bit2 |= ((value >> 2) & 1) << shift;
+            bit3 |= ((value >> 3) & 1) << shift;
+            shift--;
+            pos++;
+        }
+        quantQueryByte[out] = (byte) bit0;
+        quantQueryByte[out + plane1] = (byte) bit1;
+        quantQueryByte[out + plane2] = (byte) bit2;
+        quantQueryByte[out + plane3] = (byte) bit3;
+    }
 }

+ 2 - 0
libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java

@@ -65,4 +65,6 @@ public interface ESVectorUtilSupport {
     );
 
     void packAsBinary(int[] vector, byte[] packed);
+
+    /**
+     * Transposes the half-byte quantized values of {@code q} into bit planes stored in {@code quantQueryByte}.
+     *
+     * @param q the values to transpose
+     * @param quantQueryByte the array that receives the transposed bits
+     */
+    void transposeHalfByte(int[] q, byte[] quantQueryByte);
 }

+ 101 - 0
libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java

@@ -22,6 +22,7 @@ import org.apache.lucene.util.BitUtil;
 import org.apache.lucene.util.Constants;
 
 import static jdk.incubator.vector.VectorOperators.ADD;
+import static jdk.incubator.vector.VectorOperators.ASHR;
 import static jdk.incubator.vector.VectorOperators.LSHL;
 import static jdk.incubator.vector.VectorOperators.MAX;
 import static jdk.incubator.vector.VectorOperators.MIN;
@@ -1021,4 +1022,104 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
         }
         packed[index] = result;
     }
+
+    /**
+     * Picks the widest vectorized transpose available and falls back to the scalar implementation
+     * for inputs shorter than eight values, when fast integer vectors are unavailable, or when the
+     * vector bit size is neither 128 nor at least 256.
+     */
+    @Override
+    public void transposeHalfByte(int[] q, byte[] quantQueryByte) {
+        // One output byte per plane covers 8 input values, so shorter inputs gain nothing from SIMD.
+        if (q.length < 8 || HAS_FAST_INTEGER_VECTORS == false) {
+            DefaultESVectorUtilSupport.transposeHalfByteImpl(q, quantQueryByte);
+            return;
+        }
+        if (VECTOR_BITSIZE >= 256) {
+            transposeHalfByte256(q, quantQueryByte);
+        } else if (VECTOR_BITSIZE == 128) {
+            transposeHalfByte128(q, quantQueryByte);
+        } else {
+            DefaultESVectorUtilSupport.transposeHalfByteImpl(q, quantQueryByte);
+        }
+    }
+
+    /**
+     * Vectorized transpose that consumes one {@code INT_SPECIES_256} vector of values per output byte.
+     * Each of the four bits is isolated lane-wise, shifted by {@code SHIFTS_256} (presumably the
+     * per-lane shifts 7..0 -- confirm against its definition) and OR-reduced into a single byte.
+     * Values left over after the last full vector are packed with scalar code.
+     */
+    private void transposeHalfByte256(int[] q, byte[] quantQueryByte) {
+        final int lanes = INT_SPECIES_256.length();
+        final int vectorEnd = INT_SPECIES_256.loopBound(q.length);
+        // Start offsets of the second, third and fourth bit planes in the output.
+        final int plane1 = quantQueryByte.length / 4;
+        final int plane2 = quantQueryByte.length / 2;
+        final int plane3 = 3 * quantQueryByte.length / 4;
+
+        int pos = 0;
+        int out = 0;
+        while (pos < vectorEnd) {
+            final IntVector values = IntVector.fromArray(INT_SPECIES_256, q, pos);
+
+            final int bit0 = values.and(1).lanewise(LSHL, SHIFTS_256).reduceLanes(VectorOperators.OR);
+            final int bit1 = values.lanewise(ASHR, 1).and(1).lanewise(LSHL, SHIFTS_256).reduceLanes(VectorOperators.OR);
+            final int bit2 = values.lanewise(ASHR, 2).and(1).lanewise(LSHL, SHIFTS_256).reduceLanes(VectorOperators.OR);
+            final int bit3 = values.lanewise(ASHR, 3).and(1).lanewise(LSHL, SHIFTS_256).reduceLanes(VectorOperators.OR);
+
+            quantQueryByte[out] = (byte) bit0;
+            quantQueryByte[out + plane1] = (byte) bit1;
+            quantQueryByte[out + plane2] = (byte) bit2;
+            quantQueryByte[out + plane3] = (byte) bit3;
+
+            pos += lanes;
+            out++;
+        }
+        if (pos == q.length) {
+            // Every value was handled by the vector loop.
+            return;
+        }
+
+        // Scalar tail: the remaining values fill the high bits of one last byte per plane.
+        int bit0 = 0;
+        int bit1 = 0;
+        int bit2 = 0;
+        int bit3 = 0;
+        int shift = 7;
+        while (pos < q.length) {
+            final int value = q[pos];
+            bit0 |= (value & 1) << shift;
+            bit1 |= ((value >> 1) & 1) << shift;
+            bit2 |= ((value >> 2) & 1) << shift;
+            bit3 |= ((value >> 3) & 1) << shift;
+            shift--;
+            pos++;
+        }
+        quantQueryByte[out] = (byte) bit0;
+        quantQueryByte[out + plane1] = (byte) bit1;
+        quantQueryByte[out + plane2] = (byte) bit2;
+        quantQueryByte[out + plane3] = (byte) bit3;
+    }
+
+    /**
+     * Vectorized transpose that combines two consecutive {@code INT_SPECIES_128} vectors per output byte.
+     * The first vector is shifted by {@code HIGH_SHIFTS_128} and the second by {@code LOW_SHIFTS_128}
+     * (presumably the shifts for the upper and lower half of the byte -- confirm against their
+     * definitions); both halves are OR-ed together and reduced to one byte per bit plane.
+     * Values left over after the last full pair of vectors are packed with scalar code.
+     */
+    private void transposeHalfByte128(int[] q, byte[] quantQueryByte) {
+        final int lanes = INT_SPECIES_128.length();
+        // Each iteration reads two vectors, so stop one vector short of the usual loop bound.
+        final int pairEnd = INT_SPECIES_128.loopBound(q.length) - lanes;
+        // Start offsets of the second, third and fourth bit planes in the output.
+        final int plane1 = quantQueryByte.length / 4;
+        final int plane2 = quantQueryByte.length / 2;
+        final int plane3 = 3 * quantQueryByte.length / 4;
+
+        int pos = 0;
+        int out = 0;
+        while (pos < pairEnd) {
+            final IntVector first = IntVector.fromArray(INT_SPECIES_128, q, pos);
+            var high0 = first.and(1).lanewise(LSHL, HIGH_SHIFTS_128);
+            var high1 = first.lanewise(ASHR, 1).and(1).lanewise(LSHL, HIGH_SHIFTS_128);
+            var high2 = first.lanewise(ASHR, 2).and(1).lanewise(LSHL, HIGH_SHIFTS_128);
+            var high3 = first.lanewise(ASHR, 3).and(1).lanewise(LSHL, HIGH_SHIFTS_128);
+
+            final IntVector second = IntVector.fromArray(INT_SPECIES_128, q, pos + lanes);
+            var low0 = second.and(1).lanewise(LSHL, LOW_SHIFTS_128);
+            var low1 = second.lanewise(ASHR, 1).and(1).lanewise(LSHL, LOW_SHIFTS_128);
+            var low2 = second.lanewise(ASHR, 2).and(1).lanewise(LSHL, LOW_SHIFTS_128);
+            var low3 = second.lanewise(ASHR, 3).and(1).lanewise(LSHL, LOW_SHIFTS_128);
+
+            quantQueryByte[out] = (byte) high0.lanewise(OR, low0).reduceLanes(OR);
+            quantQueryByte[out + plane1] = (byte) high1.lanewise(OR, low1).reduceLanes(OR);
+            quantQueryByte[out + plane2] = (byte) high2.lanewise(OR, low2).reduceLanes(OR);
+            quantQueryByte[out + plane3] = (byte) high3.lanewise(OR, low3).reduceLanes(OR);
+
+            pos += 2 * lanes;
+            out++;
+        }
+        if (pos == q.length) {
+            // Every value was handled by the vector loop.
+            return;
+        }
+
+        // Scalar tail: the remaining values fill the high bits of one last byte per plane.
+        int bit0 = 0;
+        int bit1 = 0;
+        int bit2 = 0;
+        int bit3 = 0;
+        int shift = 7;
+        while (pos < q.length) {
+            final int value = q[pos];
+            bit0 |= (value & 1) << shift;
+            bit1 |= ((value >> 1) & 1) << shift;
+            bit2 |= ((value >> 2) & 1) << shift;
+            bit3 |= ((value >> 3) & 1) << shift;
+            shift--;
+            pos++;
+        }
+        quantQueryByte[out] = (byte) bit0;
+        quantQueryByte[out + plane1] = (byte) bit1;
+        quantQueryByte[out + plane2] = (byte) bit2;
+        quantQueryByte[out + plane3] = (byte) bit3;
+    }
 }

+ 14 - 0
libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java

@@ -370,6 +370,20 @@ public class ESVectorUtilTests extends BaseVectorizationTests {
         assertArrayEquals(packedLegacy, packed);
     }
 
+    /** Checks that the default and the (possibly Panama) provider produce identical transposed bytes. */
+    public void testTransposeHalfByte() {
+        final int dims = randomIntBetween(16, 2048);
+        final int[] halfBytes = new int[dims];
+        for (int d = 0; d < halfBytes.length; d++) {
+            halfBytes[d] = randomInt(15);
+        }
+        // Four bit planes, each sized for the dimension count rounded by discretize(dims, 64).
+        final int packedLength = 4 * BQVectorUtils.discretize(dims, 64) / 8;
+        final byte[] expected = new byte[packedLength];
+        final byte[] actual = new byte[packedLength];
+        defaultedProvider.getVectorUtilSupport().transposeHalfByte(halfBytes, expected);
+        defOrPanamaProvider.getVectorUtilSupport().transposeHalfByte(halfBytes, actual);
+        assertArrayEquals(expected, actual);
+    }
+
     private float[] generateRandomVector(int size) {
         float[] vector = new float[size];
         for (int i = 0; i < size; ++i) {

+ 3 - 42
server/src/main/java/org/elasticsearch/index/codec/vectors/BQSpaceUtils.java

@@ -19,6 +19,8 @@
  */
 package org.elasticsearch.index.codec.vectors;
 
+import org.elasticsearch.simdvec.ESVectorUtil;
+
 /** Utility class for quantization calculations */
 public class BQSpaceUtils {
 
@@ -117,48 +119,7 @@ public class BQSpaceUtils {
      * @param quantQueryByte the byte array to store the transposed query vector
      * */
     public static void transposeHalfByte(int[] q, byte[] quantQueryByte) {
-        int limit = q.length - 7;
-        int i = 0;
-        int index = 0;
-        for (; i < limit; i += 8, index++) {
-            assert q[i] >= 0 && q[i] <= 15;
-            assert q[i + 1] >= 0 && q[i + 1] <= 15;
-            assert q[i + 2] >= 0 && q[i + 2] <= 15;
-            assert q[i + 3] >= 0 && q[i + 3] <= 15;
-            assert q[i + 4] >= 0 && q[i + 4] <= 15;
-            assert q[i + 5] >= 0 && q[i + 5] <= 15;
-            assert q[i + 6] >= 0 && q[i + 6] <= 15;
-            assert q[i + 7] >= 0 && q[i + 7] <= 15;
-            int lowerByte = (q[i] & 1) << 7 | (q[i + 1] & 1) << 6 | (q[i + 2] & 1) << 5 | (q[i + 3] & 1) << 4 | (q[i + 4] & 1) << 3 | (q[i
-                + 5] & 1) << 2 | (q[i + 6] & 1) << 1 | (q[i + 7] & 1);
-            int lowerMiddleByte = ((q[i] >> 1) & 1) << 7 | ((q[i + 1] >> 1) & 1) << 6 | ((q[i + 2] >> 1) & 1) << 5 | ((q[i + 3] >> 1) & 1)
-                << 4 | ((q[i + 4] >> 1) & 1) << 3 | ((q[i + 5] >> 1) & 1) << 2 | ((q[i + 6] >> 1) & 1) << 1 | ((q[i + 7] >> 1) & 1);
-            int upperMiddleByte = ((q[i] >> 2) & 1) << 7 | ((q[i + 1] >> 2) & 1) << 6 | ((q[i + 2] >> 2) & 1) << 5 | ((q[i + 3] >> 2) & 1)
-                << 4 | ((q[i + 4] >> 2) & 1) << 3 | ((q[i + 5] >> 2) & 1) << 2 | ((q[i + 6] >> 2) & 1) << 1 | ((q[i + 7] >> 2) & 1);
-            int upperByte = ((q[i] >> 3) & 1) << 7 | ((q[i + 1] >> 3) & 1) << 6 | ((q[i + 2] >> 3) & 1) << 5 | ((q[i + 3] >> 3) & 1) << 4
-                | ((q[i + 4] >> 3) & 1) << 3 | ((q[i + 5] >> 3) & 1) << 2 | ((q[i + 6] >> 3) & 1) << 1 | ((q[i + 7] >> 3) & 1);
-            quantQueryByte[index] = (byte) lowerByte;
-            quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte;
-            quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte;
-            quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
-        }
-        if (i == q.length) {
-            return; // all done
-        }
-        int lowerByte = 0;
-        int lowerMiddleByte = 0;
-        int upperMiddleByte = 0;
-        int upperByte = 0;
-        for (int j = 7; i < q.length; j--, i++) {
-            lowerByte |= (q[i] & 1) << j;
-            lowerMiddleByte |= ((q[i] >> 1) & 1) << j;
-            upperMiddleByte |= ((q[i] >> 2) & 1) << j;
-            upperByte |= ((q[i] >> 3) & 1) << j;
-        }
-        quantQueryByte[index] = (byte) lowerByte;
-        quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte;
-        quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte;
-        quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
+        // The bit-transposition logic lives in ESVectorUtil; this method only forwards to it.
+        ESVectorUtil.transposeHalfByte(q, quantQueryByte);
     }
 
     /**