|
@@ -18,6 +18,7 @@ import jdk.incubator.vector.VectorSpecies;
|
|
|
|
|
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
|
|
import org.apache.lucene.store.IndexInput;
|
|
|
+import org.apache.lucene.util.BitUtil;
|
|
|
import org.apache.lucene.util.VectorUtil;
|
|
|
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
|
|
|
|
|
@@ -118,8 +119,22 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
|
|
|
subRet2 += sum2.reduceLanes(VectorOperators.ADD);
|
|
|
subRet3 += sum3.reduceLanes(VectorOperators.ADD);
|
|
|
}
|
|
|
- // tail as bytes
|
|
|
+ // process scalar tail
|
|
|
in.seek(offset);
|
|
|
+ for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
|
|
|
+ final long value = in.readLong();
|
|
|
+ subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value);
|
|
|
+ subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value);
|
|
|
+ subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value);
|
|
|
+ subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value);
|
|
|
+ }
|
|
|
+ for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
|
|
|
+ final int value = in.readInt();
|
|
|
+ subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value);
|
|
|
+ subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value);
|
|
|
+ subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value);
|
|
|
+ subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value);
|
|
|
+ }
|
|
|
for (; i < length; i++) {
|
|
|
int dValue = in.readByte() & 0xFF;
|
|
|
subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF);
|
|
@@ -158,14 +173,28 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
|
|
|
subRet1 += sum1.reduceLanes(VectorOperators.ADD);
|
|
|
subRet2 += sum2.reduceLanes(VectorOperators.ADD);
|
|
|
subRet3 += sum3.reduceLanes(VectorOperators.ADD);
|
|
|
- // tail as bytes
|
|
|
+ // process scalar tail
|
|
|
in.seek(offset);
|
|
|
+ for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
|
|
|
+ final long value = in.readLong();
|
|
|
+ subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value);
|
|
|
+ subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value);
|
|
|
+ subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value);
|
|
|
+ subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value);
|
|
|
+ }
|
|
|
+ for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
|
|
|
+ final int value = in.readInt();
|
|
|
+ subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value);
|
|
|
+ subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value);
|
|
|
+ subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value);
|
|
|
+ subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value);
|
|
|
+ }
|
|
|
for (; i < length; i++) {
|
|
|
int dValue = in.readByte() & 0xFF;
|
|
|
- subRet0 += Integer.bitCount((dValue & q[i]) & 0xFF);
|
|
|
- subRet1 += Integer.bitCount((dValue & q[i + length]) & 0xFF);
|
|
|
- subRet2 += Integer.bitCount((dValue & q[i + 2 * length]) & 0xFF);
|
|
|
- subRet3 += Integer.bitCount((dValue & q[i + 3 * length]) & 0xFF);
|
|
|
+ subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF);
|
|
|
+ subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF);
|
|
|
+ subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF);
|
|
|
+ subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF);
|
|
|
}
|
|
|
return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
|
|
|
}
|
|
@@ -215,14 +244,28 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
|
|
|
subRet1 += sum1.reduceLanes(VectorOperators.ADD);
|
|
|
subRet2 += sum2.reduceLanes(VectorOperators.ADD);
|
|
|
subRet3 += sum3.reduceLanes(VectorOperators.ADD);
|
|
|
- // tail as bytes
|
|
|
+ // process scalar tail
|
|
|
in.seek(offset);
|
|
|
+ for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
|
|
|
+ final long value = in.readLong();
|
|
|
+ subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value);
|
|
|
+ subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value);
|
|
|
+ subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value);
|
|
|
+ subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value);
|
|
|
+ }
|
|
|
+ for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
|
|
|
+ final int value = in.readInt();
|
|
|
+ subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value);
|
|
|
+ subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value);
|
|
|
+ subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value);
|
|
|
+ subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value);
|
|
|
+ }
|
|
|
for (; i < length; i++) {
|
|
|
int dValue = in.readByte() & 0xFF;
|
|
|
- subRet0 += Integer.bitCount((dValue & q[i]) & 0xFF);
|
|
|
- subRet1 += Integer.bitCount((dValue & q[i + length]) & 0xFF);
|
|
|
- subRet2 += Integer.bitCount((dValue & q[i + 2 * length]) & 0xFF);
|
|
|
- subRet3 += Integer.bitCount((dValue & q[i + 3 * length]) & 0xFF);
|
|
|
+ subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF);
|
|
|
+ subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF);
|
|
|
+ subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF);
|
|
|
+ subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF);
|
|
|
}
|
|
|
scores[iter] = subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
|
|
|
}
|
|
@@ -281,8 +324,22 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
|
|
|
subRet2 += sum2.reduceLanes(VectorOperators.ADD);
|
|
|
subRet3 += sum3.reduceLanes(VectorOperators.ADD);
|
|
|
}
|
|
|
- // tail as bytes
|
|
|
+ // process scalar tail
|
|
|
in.seek(offset);
|
|
|
+ for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
|
|
|
+ final long value = in.readLong();
|
|
|
+ subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value);
|
|
|
+ subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value);
|
|
|
+ subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value);
|
|
|
+ subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value);
|
|
|
+ }
|
|
|
+ for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
|
|
|
+ final int value = in.readInt();
|
|
|
+ subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value);
|
|
|
+ subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value);
|
|
|
+ subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value);
|
|
|
+ subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value);
|
|
|
+ }
|
|
|
for (; i < length; i++) {
|
|
|
int dValue = in.readByte() & 0xFF;
|
|
|
subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF);
|