Przeglądaj źródła

Speed up tail computation in MemorySegmentES91OSQVectorsScorer (#132001)

Ignacio Vera 2 miesięcy temu
rodzic
commit
1771d00a0f

+ 1 - 1
benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java

@@ -52,7 +52,7 @@ public class OSQScorerBenchmark {
         LogConfigurator.configureESLogging(); // native access requires logging to be initialized
     }
 
-    @Param({ "1024" })
+    @Param({ "384", "782", "1024" })
     int dims;
 
     int length;

+ 69 - 12
libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java

@@ -18,6 +18,7 @@ import jdk.incubator.vector.VectorSpecies;
 
 import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.util.BitUtil;
 import org.apache.lucene.util.VectorUtil;
 import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
 
@@ -118,8 +119,22 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
             subRet2 += sum2.reduceLanes(VectorOperators.ADD);
             subRet3 += sum3.reduceLanes(VectorOperators.ADD);
         }
-        // tail as bytes
+        // process scalar tail
         in.seek(offset);
+        for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
+            final long value = in.readLong();
+            subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value);
+            subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value);
+            subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value);
+            subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value);
+        }
+        for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
+            final int value = in.readInt();
+            subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value);
+            subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value);
+            subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value);
+            subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value);
+        }
         for (; i < length; i++) {
             int dValue = in.readByte() & 0xFF;
             subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF);
@@ -158,14 +173,28 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
         subRet1 += sum1.reduceLanes(VectorOperators.ADD);
         subRet2 += sum2.reduceLanes(VectorOperators.ADD);
         subRet3 += sum3.reduceLanes(VectorOperators.ADD);
-        // tail as bytes
+        // process scalar tail
         in.seek(offset);
+        for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
+            final long value = in.readLong();
+            subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value);
+            subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value);
+            subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value);
+            subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value);
+        }
+        for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
+            final int value = in.readInt();
+            subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value);
+            subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value);
+            subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value);
+            subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value);
+        }
         for (; i < length; i++) {
             int dValue = in.readByte() & 0xFF;
-            subRet0 += Integer.bitCount((dValue & q[i]) & 0xFF);
-            subRet1 += Integer.bitCount((dValue & q[i + length]) & 0xFF);
-            subRet2 += Integer.bitCount((dValue & q[i + 2 * length]) & 0xFF);
-            subRet3 += Integer.bitCount((dValue & q[i + 3 * length]) & 0xFF);
+            subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF);
+            subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF);
+            subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF);
+            subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF);
         }
         return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
     }
@@ -215,14 +244,28 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
             subRet1 += sum1.reduceLanes(VectorOperators.ADD);
             subRet2 += sum2.reduceLanes(VectorOperators.ADD);
             subRet3 += sum3.reduceLanes(VectorOperators.ADD);
-            // tail as bytes
+            // process scalar tail
             in.seek(offset);
+            for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
+                final long value = in.readLong();
+                subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value);
+                subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value);
+                subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value);
+                subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value);
+            }
+            for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
+                final int value = in.readInt();
+                subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value);
+                subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value);
+                subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value);
+                subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value);
+            }
             for (; i < length; i++) {
                 int dValue = in.readByte() & 0xFF;
-                subRet0 += Integer.bitCount((dValue & q[i]) & 0xFF);
-                subRet1 += Integer.bitCount((dValue & q[i + length]) & 0xFF);
-                subRet2 += Integer.bitCount((dValue & q[i + 2 * length]) & 0xFF);
-                subRet3 += Integer.bitCount((dValue & q[i + 3 * length]) & 0xFF);
+                subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF);
+                subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF);
+                subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF);
+                subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF);
             }
             scores[iter] = subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
         }
@@ -281,8 +324,22 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
                 subRet2 += sum2.reduceLanes(VectorOperators.ADD);
                 subRet3 += sum3.reduceLanes(VectorOperators.ADD);
             }
-            // tail as bytes
+            // process scalar tail
             in.seek(offset);
+            for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
+                final long value = in.readLong();
+                subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value);
+                subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value);
+                subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value);
+                subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value);
+            }
+            for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
+                final int value = in.readInt();
+                subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value);
+                subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value);
+                subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value);
+                subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value);
+            }
             for (; i < length; i++) {
                 int dValue = in.readByte() & 0xFF;
                 subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF);