|
@@ -352,7 +352,7 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
- public void scoreBulk(
|
|
|
+ public float scoreBulk(
|
|
|
byte[] q,
|
|
|
float queryLowerInterval,
|
|
|
float queryUpperInterval,
|
|
@@ -366,7 +366,7 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
|
|
|
// 128 / 8 == 16
|
|
|
if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) {
|
|
|
if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) {
|
|
|
- score256Bulk(
|
|
|
+ return score256Bulk(
|
|
|
q,
|
|
|
queryLowerInterval,
|
|
|
queryUpperInterval,
|
|
@@ -376,9 +376,8 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
|
|
|
centroidDp,
|
|
|
scores
|
|
|
);
|
|
|
- return;
|
|
|
} else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) {
|
|
|
- score128Bulk(
|
|
|
+ return score128Bulk(
|
|
|
q,
|
|
|
queryLowerInterval,
|
|
|
queryUpperInterval,
|
|
@@ -388,10 +387,9 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
|
|
|
centroidDp,
|
|
|
scores
|
|
|
);
|
|
|
- return;
|
|
|
}
|
|
|
}
|
|
|
- super.scoreBulk(
|
|
|
+ return super.scoreBulk(
|
|
|
q,
|
|
|
queryLowerInterval,
|
|
|
queryUpperInterval,
|
|
@@ -403,7 +401,7 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
|
|
|
);
|
|
|
}
|
|
|
|
|
|
- private void score128Bulk(
|
|
|
+ private float score128Bulk(
|
|
|
byte[] q,
|
|
|
float queryLowerInterval,
|
|
|
float queryUpperInterval,
|
|
@@ -420,6 +418,7 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
|
|
|
float ay = queryLowerInterval;
|
|
|
float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE;
|
|
|
float y1 = queryComponentSum;
|
|
|
+ float maxScore = Float.NEGATIVE_INFINITY;
|
|
|
for (; i < limit; i += FLOAT_SPECIES_128.length()) {
|
|
|
var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_128, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
|
|
|
var lx = FloatVector.fromMemorySegment(
|
|
@@ -453,6 +452,7 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
|
|
|
if (similarityFunction == EUCLIDEAN) {
|
|
|
res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f);
|
|
|
res = FloatVector.broadcast(FLOAT_SPECIES_128, 1).div(res).max(0);
|
|
|
+ maxScore = Math.max(maxScore, res.reduceLanes(VectorOperators.MAX)); // accumulate across chunks, do not overwrite
|
|
|
res.intoArray(scores, i);
|
|
|
} else {
|
|
|
// For cosine and max inner product, we need to apply the additional correction, which is
|
|
@@ -463,17 +463,20 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
|
|
|
// not sure how to do it better
|
|
|
for (int j = 0; j < FLOAT_SPECIES_128.length(); j++) {
|
|
|
scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]);
|
|
|
+ maxScore = Math.max(maxScore, scores[i + j]);
|
|
|
}
|
|
|
} else {
|
|
|
res = res.add(1f).mul(0.5f).max(0);
|
|
|
res.intoArray(scores, i);
|
|
|
+ maxScore = Math.max(maxScore, res.reduceLanes(VectorOperators.MAX)); // accumulate across chunks, do not overwrite
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
in.seek(offset + 14L * BULK_SIZE);
|
|
|
+ return maxScore;
|
|
|
}
|
|
|
|
|
|
- private void score256Bulk(
|
|
|
+ private float score256Bulk(
|
|
|
byte[] q,
|
|
|
float queryLowerInterval,
|
|
|
float queryUpperInterval,
|
|
@@ -490,6 +493,7 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
|
|
|
float ay = queryLowerInterval;
|
|
|
float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE;
|
|
|
float y1 = queryComponentSum;
|
|
|
+ float maxScore = Float.NEGATIVE_INFINITY;
|
|
|
for (; i < limit; i += FLOAT_SPECIES_256.length()) {
|
|
|
var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_256, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
|
|
|
var lx = FloatVector.fromMemorySegment(
|
|
@@ -523,6 +527,7 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
|
|
|
if (similarityFunction == EUCLIDEAN) {
|
|
|
res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f);
|
|
|
res = FloatVector.broadcast(FLOAT_SPECIES_256, 1).div(res).max(0);
|
|
|
+ maxScore = Math.max(maxScore, res.reduceLanes(VectorOperators.MAX)); // accumulate across chunks, do not overwrite
|
|
|
res.intoArray(scores, i);
|
|
|
} else {
|
|
|
// For cosine and max inner product, we need to apply the additional correction, which is
|
|
@@ -533,13 +538,16 @@ public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScore
|
|
|
// not sure how to do it better
|
|
|
for (int j = 0; j < FLOAT_SPECIES_256.length(); j++) {
|
|
|
scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]);
|
|
|
+ maxScore = Math.max(maxScore, scores[i + j]);
|
|
|
}
|
|
|
} else {
|
|
|
res = res.add(1f).mul(0.5f).max(0);
|
|
|
+ maxScore = Math.max(maxScore, res.reduceLanes(VectorOperators.MAX)); // accumulate across chunks, do not overwrite
|
|
|
res.intoArray(scores, i);
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
in.seek(offset + 14L * BULK_SIZE);
|
|
|
+ return maxScore;
|
|
|
}
|
|
|
}
|